diff --git a/crates/wasmi/benches/wat/count_until.wat b/crates/wasmi/benches/wat/count_until.wat index 46c4c0ed4b..bd13f3c291 100644 --- a/crates/wasmi/benches/wat/count_until.wat +++ b/crates/wasmi/benches/wat/count_until.wat @@ -3,21 +3,18 @@ (module (func (export "count_until") (param $limit i32) (result i32) (local $counter i32) - (block - (loop - (br_if - 1 - (i32.eq - (local.tee $counter - (i32.add - (local.get $counter) - (i32.const 1) - ) + (loop + (br_if + 0 + (i32.ne + (local.tee $counter + (i32.add + (local.get $counter) + (i32.const 1) ) - (local.get $limit) ) + (local.get $limit) ) - (br 0) ) ) (return (local.get $counter)) diff --git a/crates/wasmi/src/engine/bytecode/utils.rs b/crates/wasmi/src/engine/bytecode/utils.rs index adbcc077f1..34c66b0af7 100644 --- a/crates/wasmi/src/engine/bytecode/utils.rs +++ b/crates/wasmi/src/engine/bytecode/utils.rs @@ -273,7 +273,6 @@ impl AddressOffset { #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct BranchOffset(i32); -#[cfg(test)] impl From for BranchOffset { fn from(index: i32) -> Self { Self(index) diff --git a/crates/wasmi/src/engine/regmach/bytecode/construct.rs b/crates/wasmi/src/engine/regmach/bytecode/construct.rs index 68f2bf3446..5ba856b30f 100644 --- a/crates/wasmi/src/engine/regmach/bytecode/construct.rs +++ b/crates/wasmi/src/engine/regmach/bytecode/construct.rs @@ -1,8 +1,10 @@ use super::{ - utils::{CopysignImmInstr, Sign}, + utils::{BranchOffset16, CopysignImmInstr, Sign}, AnyConst32, BinInstr, BinInstrImm16, + BranchBinOpInstr, + BranchBinOpInstrImm, CallIndirectParams, Const16, Const32, @@ -160,6 +162,92 @@ macro_rules! constructor_for { }; } +macro_rules! constructor_for_branch_binop { + ( $( fn $name:ident() -> Self::$op_code:ident; )* ) => { + impl Instruction { + $( + #[doc = concat!("Creates a new [`Instruction::", stringify!($op_code), "`].")] + pub fn $name(lhs: Register, rhs: Register, offset: BranchOffset16) -> Self { + Self::$op_code(BranchBinOpInstr::new(lhs, rhs, offset)) + } + )* + } + } +} +constructor_for_branch_binop! { + fn branch_i32_eq() -> Self::BranchI32Eq; + fn branch_i32_ne() -> Self::BranchI32Ne; + fn branch_i32_lt_s() -> Self::BranchI32LtS; + fn branch_i32_lt_u() -> Self::BranchI32LtU; + fn branch_i32_le_s() -> Self::BranchI32LeS; + fn branch_i32_le_u() -> Self::BranchI32LeU; + fn branch_i32_gt_s() -> Self::BranchI32GtS; + fn branch_i32_gt_u() -> Self::BranchI32GtU; + fn branch_i32_ge_s() -> Self::BranchI32GeS; + fn branch_i32_ge_u() -> Self::BranchI32GeU; + + fn branch_i64_eq() -> Self::BranchI64Eq; + fn branch_i64_ne() -> Self::BranchI64Ne; + fn branch_i64_lt_s() -> Self::BranchI64LtS; + fn branch_i64_lt_u() -> Self::BranchI64LtU; + fn branch_i64_le_s() -> Self::BranchI64LeS; + fn branch_i64_le_u() -> Self::BranchI64LeU; + fn branch_i64_gt_s() -> Self::BranchI64GtS; + fn branch_i64_gt_u() -> Self::BranchI64GtU; + fn branch_i64_ge_s() -> Self::BranchI64GeS; + fn branch_i64_ge_u() -> Self::BranchI64GeU; + + fn branch_f32_eq() -> Self::BranchF32Eq; + fn branch_f32_ne() -> Self::BranchF32Ne; + fn branch_f32_lt() -> Self::BranchF32Lt; + fn branch_f32_le() -> Self::BranchF32Le; + fn branch_f32_gt() -> Self::BranchF32Gt; + fn branch_f32_ge() -> Self::BranchF32Ge; + + fn branch_f64_eq() -> Self::BranchF64Eq; + fn branch_f64_ne() -> Self::BranchF64Ne; + fn branch_f64_lt() -> Self::BranchF64Lt; + fn branch_f64_le() -> Self::BranchF64Le; + fn branch_f64_gt() -> Self::BranchF64Gt; + fn branch_f64_ge() -> Self::BranchF64Ge; +} + +macro_rules! constructor_for_branch_binop_imm { + ( $( fn $name:ident($ty:ty) -> Self::$op_code:ident; )* ) => { + impl Instruction { + $( + #[doc = concat!("Creates a new [`Instruction::", stringify!($op_code), "`].")] + pub fn $name(lhs: Register, rhs: Const16<$ty>, offset: BranchOffset16) -> Self { + Self::$op_code(BranchBinOpInstrImm::new(lhs, rhs, offset)) + } + )* + } + } +} +constructor_for_branch_binop_imm! { + fn branch_i32_eq_imm(i32) -> Self::BranchI32EqImm; + fn branch_i32_ne_imm(i32) -> Self::BranchI32NeImm; + fn branch_i32_lt_s_imm(i32) -> Self::BranchI32LtSImm; + fn branch_i32_lt_u_imm(u32) -> Self::BranchI32LtUImm; + fn branch_i32_le_s_imm(i32) -> Self::BranchI32LeSImm; + fn branch_i32_le_u_imm(u32) -> Self::BranchI32LeUImm; + fn branch_i32_gt_s_imm(i32) -> Self::BranchI32GtSImm; + fn branch_i32_gt_u_imm(u32) -> Self::BranchI32GtUImm; + fn branch_i32_ge_s_imm(i32) -> Self::BranchI32GeSImm; + fn branch_i32_ge_u_imm(u32) -> Self::BranchI32GeUImm; + + fn branch_i64_eq_imm(i64) -> Self::BranchI64EqImm; + fn branch_i64_ne_imm(i64) -> Self::BranchI64NeImm; + fn branch_i64_lt_s_imm(i64) -> Self::BranchI64LtSImm; + fn branch_i64_lt_u_imm(u64) -> Self::BranchI64LtUImm; + fn branch_i64_le_s_imm(i64) -> Self::BranchI64LeSImm; + fn branch_i64_le_u_imm(u64) -> Self::BranchI64LeUImm; + fn branch_i64_gt_s_imm(i64) -> Self::BranchI64GtSImm; + fn branch_i64_gt_u_imm(u64) -> Self::BranchI64GtUImm; + fn branch_i64_ge_s_imm(i64) -> Self::BranchI64GeSImm; + fn branch_i64_ge_u_imm(u64) -> Self::BranchI64GeUImm; +} + impl Instruction { /// Creates a new [`Instruction::Const32`] from the given `value`. pub fn const32(value: impl Into) -> Self { diff --git a/crates/wasmi/src/engine/regmach/bytecode/immediate.rs b/crates/wasmi/src/engine/regmach/bytecode/immediate.rs index 1f490b54a5..06dbf40c08 100644 --- a/crates/wasmi/src/engine/regmach/bytecode/immediate.rs +++ b/crates/wasmi/src/engine/regmach/bytecode/immediate.rs @@ -14,6 +14,20 @@ pub struct Const16 { marker: PhantomData T>, } +impl Const16 { + /// Returns `true` if the [`Const16`]`` is equal to zero. + pub fn is_zero(&self) -> bool { + self.inner == AnyConst16::from(0_i16) + } +} + +impl Const16 { + /// Returns `true` if the [`Const16`]`` is equal to zero. + pub fn is_zero(&self) -> bool { + self.inner == AnyConst16::from(0_i16) + } +} + impl Const16 { /// Crete a new typed [`Const16`] value. pub fn new(inner: AnyConst16) -> Self { diff --git a/crates/wasmi/src/engine/regmach/bytecode/mod.rs b/crates/wasmi/src/engine/regmach/bytecode/mod.rs index 82b064aee4..8ea97232b7 100644 --- a/crates/wasmi/src/engine/regmach/bytecode/mod.rs +++ b/crates/wasmi/src/engine/regmach/bytecode/mod.rs @@ -12,6 +12,9 @@ pub(crate) use self::{ utils::{ BinInstr, BinInstrImm16, + BranchBinOpInstr, + BranchBinOpInstrImm, + BranchOffset16, CallIndirectParams, CopysignImmInstr, LoadAtInstr, @@ -387,6 +390,198 @@ pub enum Instruction { offset: BranchOffset, }, + /// A fused [`Instruction::I32Eq`] and [`Instruction::BranchNez`] instruction. + BranchI32Eq(BranchBinOpInstr), + /// A fused [`Instruction::I32Eq`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32Eq`] with 16-bit encoded constant `rhs`. + BranchI32EqImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32Ne`] and [`Instruction::BranchNez`] instruction. + BranchI32Ne(BranchBinOpInstr), + /// A fused [`Instruction::I32Ne`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32Ne`] with 16-bit encoded constant `rhs`. + BranchI32NeImm(BranchBinOpInstrImm), + + /// A fused [`Instruction::I32LtS`] and [`Instruction::BranchNez`] instruction. + BranchI32LtS(BranchBinOpInstr), + /// A fused [`Instruction::I32LtS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32LtS`] with 16-bit encoded constant `rhs`. + BranchI32LtSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32LtU`] and [`Instruction::BranchNez`] instruction. + BranchI32LtU(BranchBinOpInstr), + /// A fused [`Instruction::I32LtU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32LtU`] with 16-bit encoded constant `rhs`. + BranchI32LtUImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32LeS`] and [`Instruction::BranchNez`] instruction. + BranchI32LeS(BranchBinOpInstr), + /// A fused [`Instruction::I32LeS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32LeS`] with 16-bit encoded constant `rhs`. + BranchI32LeSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32LeU`] and [`Instruction::BranchNez`] instruction. + BranchI32LeU(BranchBinOpInstr), + /// A fused [`Instruction::I32LeU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32LeU`] with 16-bit encoded constant `rhs`. + BranchI32LeUImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32GtS`] and [`Instruction::BranchNez`] instruction. + BranchI32GtS(BranchBinOpInstr), + /// A fused [`Instruction::I32GtS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32GtS`] with 16-bit encoded constant `rhs`. + BranchI32GtSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32GtU`] and [`Instruction::BranchNez`] instruction. + BranchI32GtU(BranchBinOpInstr), + /// A fused [`Instruction::I32GtU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32GtU`] with 16-bit encoded constant `rhs`. + BranchI32GtUImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32GeS`] and [`Instruction::BranchNez`] instruction. + BranchI32GeS(BranchBinOpInstr), + /// A fused [`Instruction::I32GeS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32GeS`] with 16-bit encoded constant `rhs`. + BranchI32GeSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I32GeU`] and [`Instruction::BranchNez`] instruction. + BranchI32GeU(BranchBinOpInstr), + /// A fused [`Instruction::I32GeU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI32GeU`] with 16-bit encoded constant `rhs`. + BranchI32GeUImm(BranchBinOpInstrImm), + + /// A fused [`Instruction::I64Eq`] and [`Instruction::BranchNez`] instruction. + BranchI64Eq(BranchBinOpInstr), + /// A fused [`Instruction::I64Eq`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64Eq`] with 16-bit encoded constant `rhs`. + BranchI64EqImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64Ne`] and [`Instruction::BranchNez`] instruction. + BranchI64Ne(BranchBinOpInstr), + /// A fused [`Instruction::I64Ne`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64Ne`] with 16-bit encoded constant `rhs`. + BranchI64NeImm(BranchBinOpInstrImm), + + /// A fused [`Instruction::I64LtS`] and [`Instruction::BranchNez`] instruction. + BranchI64LtS(BranchBinOpInstr), + /// A fused [`Instruction::I64LtS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64LtS`] with 16-bit encoded constant `rhs`. + BranchI64LtSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64LtU`] and [`Instruction::BranchNez`] instruction. + BranchI64LtU(BranchBinOpInstr), + /// A fused [`Instruction::I64LtU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64LtU`] with 16-bit encoded constant `rhs`. + BranchI64LtUImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64LeS`] and [`Instruction::BranchNez`] instruction. + BranchI64LeS(BranchBinOpInstr), + /// A fused [`Instruction::I64LeS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64LeS`] with 16-bit encoded constant `rhs`. + BranchI64LeSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64LeU`] and [`Instruction::BranchNez`] instruction. + BranchI64LeU(BranchBinOpInstr), + /// A fused [`Instruction::I64LeU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64LeU`] with 16-bit encoded constant `rhs`. + BranchI64LeUImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64GtS`] and [`Instruction::BranchNez`] instruction. + BranchI64GtS(BranchBinOpInstr), + /// A fused [`Instruction::I64GtS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64GtS`] with 16-bit encoded constant `rhs`. + BranchI64GtSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64GtU`] and [`Instruction::BranchNez`] instruction. + BranchI64GtU(BranchBinOpInstr), + /// A fused [`Instruction::I64GtU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64GtU`] with 16-bit encoded constant `rhs`. + BranchI64GtUImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64GeS`] and [`Instruction::BranchNez`] instruction. + BranchI64GeS(BranchBinOpInstr), + /// A fused [`Instruction::I64GeS`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64GeS`] with 16-bit encoded constant `rhs`. + BranchI64GeSImm(BranchBinOpInstrImm), + /// A fused [`Instruction::I64GeU`] and [`Instruction::BranchNez`] instruction. + BranchI64GeU(BranchBinOpInstr), + /// A fused [`Instruction::I64GeU`] and [`Instruction::BranchNez`] instruction. + /// + /// # Note + /// + /// Variant of [`Instruction::BranchI64GeU`] with 16-bit encoded constant `rhs`. + BranchI64GeUImm(BranchBinOpInstrImm), + + /// A fused [`Instruction::F32Eq`] and [`Instruction::BranchNez`] instruction. + BranchF32Eq(BranchBinOpInstr), + /// A fused [`Instruction::F32Ne`] and [`Instruction::BranchNez`] instruction. + BranchF32Ne(BranchBinOpInstr), + + /// A fused [`Instruction::F32Lt`] and [`Instruction::BranchNez`] instruction. + BranchF32Lt(BranchBinOpInstr), + /// A fused [`Instruction::F32Le`] and [`Instruction::BranchNez`] instruction. + BranchF32Le(BranchBinOpInstr), + /// A fused [`Instruction::F32Gt`] and [`Instruction::BranchNez`] instruction. + BranchF32Gt(BranchBinOpInstr), + /// A fused [`Instruction::F32Ge`] and [`Instruction::BranchNez`] instruction. + BranchF32Ge(BranchBinOpInstr), + + /// A fused [`Instruction::F64Eq`] and [`Instruction::BranchNez`] instruction. + BranchF64Eq(BranchBinOpInstr), + /// A fused [`Instruction::F64Ne`] and [`Instruction::BranchNez`] instruction. + BranchF64Ne(BranchBinOpInstr), + + /// A fused [`Instruction::F64Lt`] and [`Instruction::BranchNez`] instruction. + BranchF64Lt(BranchBinOpInstr), + /// A fused [`Instruction::F64Le`] and [`Instruction::BranchNez`] instruction. + BranchF64Le(BranchBinOpInstr), + /// A fused [`Instruction::F64Gt`] and [`Instruction::BranchNez`] instruction. + BranchF64Gt(BranchBinOpInstr), + /// A fused [`Instruction::F64Ge`] and [`Instruction::BranchNez`] instruction. + BranchF64Ge(BranchBinOpInstr), + /// A Wasm `br_table` instruction. /// /// # Encoding diff --git a/crates/wasmi/src/engine/regmach/bytecode/utils.rs b/crates/wasmi/src/engine/regmach/bytecode/utils.rs index e4267e5d87..40ffc642fc 100644 --- a/crates/wasmi/src/engine/regmach/bytecode/utils.rs +++ b/crates/wasmi/src/engine/regmach/bytecode/utils.rs @@ -1,5 +1,9 @@ use super::{Const16, Const32}; -use crate::engine::{bytecode::TableIdx, func_builder::TranslationErrorInner, TranslationError}; +use crate::engine::{ + bytecode::{BranchOffset, TableIdx}, + func_builder::TranslationErrorInner, + TranslationError, +}; #[cfg(doc)] use super::Instruction; @@ -462,3 +466,95 @@ pub struct CallIndirectParams { /// The index of the called function in the table. pub index: T, } + +/// A 16-bit signed offset for branch instructions. +/// +/// This defines how much the instruction pointer is offset +/// upon taking the respective branch. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct BranchOffset16(i16); + +#[cfg(test)] +impl From for BranchOffset16 { + fn from(offset: i16) -> Self { + Self(offset) + } +} + +impl BranchOffset16 { + /// Creates a 16-bit [`BranchOffset16`] from a 32-bit [`BranchOffset`] if possible. + pub fn new(offset: BranchOffset) -> Option { + let offset16 = i16::try_from(offset.to_i32()).ok()?; + Some(Self(offset16)) + } + + /// Returns `true` if the [`BranchOffset16`] has been initialized. + pub fn is_init(self) -> bool { + self.to_i16() != 0 + } + + /// Initializes the [`BranchOffset`] with a proper value. + /// + /// # Panics + /// + /// - If the [`BranchOffset`] have already been initialized. + /// - If the given [`BranchOffset`] is not properly initialized. + pub fn init(&mut self, valid_offset: BranchOffset) -> Result<(), TranslationError> { + assert!(valid_offset.is_init()); + assert!(!self.is_init()); + let Some(valid_offset16) = Self::new(valid_offset) else { + return Err(TranslationError::new( + TranslationErrorInner::BranchOffsetOutOfBounds, + )); + }; + *self = valid_offset16; + Ok(()) + } + + /// Returns the `i16` representation of the [`BranchOffset`]. + pub fn to_i16(self) -> i16 { + self.0 + } +} + +impl From for BranchOffset { + fn from(offset: BranchOffset16) -> Self { + Self::from(i32::from(offset.to_i16())) + } +} + +/// A generic fused comparison and conditional branch [`Instruction`]. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct BranchBinOpInstr { + /// The left-hand side operand to the conditional operator. + pub lhs: Register, + /// The right-hand side operand to the conditional operator. + pub rhs: Register, + /// The 16-bit encoded branch offset. + pub offset: BranchOffset16, +} + +impl BranchBinOpInstr { + /// Creates a new [`BranchBinOpInstr`]. + pub fn new(lhs: Register, rhs: Register, offset: BranchOffset16) -> Self { + Self { lhs, rhs, offset } + } +} + +/// A generic fused comparison and conditional branch [`Instruction`]. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct BranchBinOpInstrImm { + /// The left-hand side operand to the conditional operator. + pub lhs: Register, + /// The right-hand side operand to the conditional operator. + pub rhs: Const16, + /// The 16-bit encoded branch offset. + pub offset: BranchOffset16, +} + +impl BranchBinOpInstrImm { + /// Creates a new [`BranchBinOpInstr`]. + pub fn new(lhs: Register, rhs: Const16, offset: BranchOffset16) -> Self { + Self { lhs, rhs, offset } + } +} diff --git a/crates/wasmi/src/engine/regmach/executor/instrs.rs b/crates/wasmi/src/engine/regmach/executor/instrs.rs index 52926acdab..eec6196df6 100644 --- a/crates/wasmi/src/engine/regmach/executor/instrs.rs +++ b/crates/wasmi/src/engine/regmach/executor/instrs.rs @@ -12,7 +12,6 @@ use crate::{ BinInstr, BinInstrImm16, Const16, - Const32, Instruction, Register, RegisterSpan, @@ -28,9 +27,9 @@ use crate::{ FuncRef, StoreInner, }; -use core::cmp; mod binary; +mod branch; mod call; mod comparison; mod conversion; @@ -255,6 +254,58 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { Instr::BranchTable { index, len_targets } => { self.execute_branch_table(index, len_targets) } + Instr::BranchI32Eq(instr) => self.execute_branch_i32_eq(instr), + Instr::BranchI32EqImm(instr) => self.execute_branch_i32_eq_imm(instr), + Instr::BranchI32Ne(instr) => self.execute_branch_i32_ne(instr), + Instr::BranchI32NeImm(instr) => self.execute_branch_i32_ne_imm(instr), + Instr::BranchI32LtS(instr) => self.execute_branch_i32_lt_s(instr), + Instr::BranchI32LtSImm(instr) => self.execute_branch_i32_lt_s_imm(instr), + Instr::BranchI32LtU(instr) => self.execute_branch_i32_lt_u(instr), + Instr::BranchI32LtUImm(instr) => self.execute_branch_i32_lt_u_imm(instr), + Instr::BranchI32LeS(instr) => self.execute_branch_i32_le_s(instr), + Instr::BranchI32LeSImm(instr) => self.execute_branch_i32_le_s_imm(instr), + Instr::BranchI32LeU(instr) => self.execute_branch_i32_le_u(instr), + Instr::BranchI32LeUImm(instr) => self.execute_branch_i32_le_u_imm(instr), + Instr::BranchI32GtS(instr) => self.execute_branch_i32_gt_s(instr), + Instr::BranchI32GtSImm(instr) => self.execute_branch_i32_gt_s_imm(instr), + Instr::BranchI32GtU(instr) => self.execute_branch_i32_gt_u(instr), + Instr::BranchI32GtUImm(instr) => self.execute_branch_i32_gt_u_imm(instr), + Instr::BranchI32GeS(instr) => self.execute_branch_i32_ge_s(instr), + Instr::BranchI32GeSImm(instr) => self.execute_branch_i32_ge_s_imm(instr), + Instr::BranchI32GeU(instr) => self.execute_branch_i32_ge_u(instr), + Instr::BranchI32GeUImm(instr) => self.execute_branch_i32_ge_u_imm(instr), + Instr::BranchI64Eq(instr) => self.execute_branch_i64_eq(instr), + Instr::BranchI64EqImm(instr) => self.execute_branch_i64_eq_imm(instr), + Instr::BranchI64Ne(instr) => self.execute_branch_i64_ne(instr), + Instr::BranchI64NeImm(instr) => self.execute_branch_i64_ne_imm(instr), + Instr::BranchI64LtS(instr) => self.execute_branch_i64_lt_s(instr), + Instr::BranchI64LtSImm(instr) => self.execute_branch_i64_lt_s_imm(instr), + Instr::BranchI64LtU(instr) => self.execute_branch_i64_lt_u(instr), + Instr::BranchI64LtUImm(instr) => self.execute_branch_i64_lt_u_imm(instr), + Instr::BranchI64LeS(instr) => self.execute_branch_i64_le_s(instr), + Instr::BranchI64LeSImm(instr) => self.execute_branch_i64_le_s_imm(instr), + Instr::BranchI64LeU(instr) => self.execute_branch_i64_le_u(instr), + Instr::BranchI64LeUImm(instr) => self.execute_branch_i64_le_u_imm(instr), + Instr::BranchI64GtS(instr) => self.execute_branch_i64_gt_s(instr), + Instr::BranchI64GtSImm(instr) => self.execute_branch_i64_gt_s_imm(instr), + Instr::BranchI64GtU(instr) => self.execute_branch_i64_gt_u(instr), + Instr::BranchI64GtUImm(instr) => self.execute_branch_i64_gt_u_imm(instr), + Instr::BranchI64GeS(instr) => self.execute_branch_i64_ge_s(instr), + Instr::BranchI64GeSImm(instr) => self.execute_branch_i64_ge_s_imm(instr), + Instr::BranchI64GeU(instr) => self.execute_branch_i64_ge_u(instr), + Instr::BranchI64GeUImm(instr) => self.execute_branch_i64_ge_u_imm(instr), + Instr::BranchF32Eq(instr) => self.execute_branch_f32_eq(instr), + Instr::BranchF32Ne(instr) => self.execute_branch_f32_ne(instr), + Instr::BranchF32Lt(instr) => self.execute_branch_f32_lt(instr), + Instr::BranchF32Le(instr) => self.execute_branch_f32_le(instr), + Instr::BranchF32Gt(instr) => self.execute_branch_f32_gt(instr), + Instr::BranchF32Ge(instr) => self.execute_branch_f32_ge(instr), + Instr::BranchF64Eq(instr) => self.execute_branch_f64_eq(instr), + Instr::BranchF64Ne(instr) => self.execute_branch_f64_ne(instr), + Instr::BranchF64Lt(instr) => self.execute_branch_f64_lt(instr), + Instr::BranchF64Le(instr) => self.execute_branch_f64_le(instr), + Instr::BranchF64Gt(instr) => self.execute_branch_f64_gt(instr), + Instr::BranchF64Ge(instr) => self.execute_branch_f64_ge(instr), Instr::Copy { result, value } => self.execute_copy(result, value), Instr::Copy2 { results, values } => self.execute_copy_2(results, values), Instr::CopyImm32 { result, value } => self.execute_copy_imm32(result, value), @@ -1172,49 +1223,6 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { self.try_next_instr() } - #[inline(always)] - fn execute_branch(&mut self, offset: BranchOffset) { - self.branch_to(offset) - } - - #[inline(always)] - fn execute_branch_nez(&mut self, condition: Register, offset: BranchOffset) { - let condition: bool = self.get_register_as(condition); - match condition { - true => { - self.branch_to(offset); - } - false => { - self.next_instr(); - } - } - } - - #[inline(always)] - fn execute_branch_eqz(&mut self, condition: Register, offset: BranchOffset) { - let condition: bool = self.get_register_as(condition); - match condition { - true => { - self.next_instr(); - } - false => { - self.branch_to(offset); - } - } - } - - #[inline(always)] - fn execute_branch_table(&mut self, index: Register, len_targets: Const32) { - // Safety: TODO - let index: u32 = self.get_register_as(index); - // The index of the default target which is the last target of the slice. - let max_index = u32::from(len_targets) - 1; - // A normalized index will always yield a target without panicking. - let normalized_index = cmp::min(index, max_index); - // Update `pc`: - self.ip.add(normalized_index as usize + 1); - } - /// Executes an [`Instruction::RefFunc`]. #[inline(always)] fn execute_ref_func(&mut self, result: Register, func_index: FuncIdx) { diff --git a/crates/wasmi/src/engine/regmach/executor/instrs/branch.rs b/crates/wasmi/src/engine/regmach/executor/instrs/branch.rs new file mode 100644 index 0000000000..039b0cbab5 --- /dev/null +++ b/crates/wasmi/src/engine/regmach/executor/instrs/branch.rs @@ -0,0 +1,176 @@ +use wasmi_core::UntypedValue; + +use super::Executor; +use crate::engine::{ + bytecode::BranchOffset, + regmach::bytecode::{BranchBinOpInstr, BranchBinOpInstrImm, Const16, Const32, Register}, +}; +use core::cmp; + +#[cfg(doc)] +use crate::engine::regmach::bytecode::Instruction; + +impl<'ctx, 'engine> Executor<'ctx, 'engine> { + #[inline(always)] + pub fn execute_branch(&mut self, offset: BranchOffset) { + self.branch_to(offset) + } + + #[inline(always)] + pub fn execute_branch_nez(&mut self, condition: Register, offset: BranchOffset) { + let condition: bool = self.get_register_as(condition); + match condition { + true => { + self.branch_to(offset); + } + false => { + self.next_instr(); + } + } + } + + #[inline(always)] + pub fn execute_branch_eqz(&mut self, condition: Register, offset: BranchOffset) { + let condition: bool = self.get_register_as(condition); + match condition { + true => { + self.next_instr(); + } + false => { + self.branch_to(offset); + } + } + } + + #[inline(always)] + pub fn execute_branch_table(&mut self, index: Register, len_targets: Const32) { + let index: u32 = self.get_register_as(index); + // The index of the default target which is the last target of the slice. + let max_index = u32::from(len_targets) - 1; + // A normalized index will always yield a target without panicking. + let normalized_index = cmp::min(index, max_index); + // Update `pc`: + self.ip.add(normalized_index as usize + 1); + } + + /// Executes a generic fused compare and branch instruction. + fn execute_branch_binop( + &mut self, + instr: BranchBinOpInstr, + f: fn(UntypedValue, UntypedValue) -> UntypedValue, + ) { + let lhs = self.get_register(instr.lhs); + let rhs = self.get_register(instr.rhs); + if bool::from(f(lhs, rhs)) { + self.branch_to(instr.offset.into()); + } else { + self.next_instr() + } + } + + /// Executes a generic fused compare and branch instruction with immediate `rhs` operand. + fn execute_branch_binop_imm( + &mut self, + instr: BranchBinOpInstrImm, + f: fn(UntypedValue, UntypedValue) -> UntypedValue, + ) where + T: From>, + UntypedValue: From, + { + let lhs = self.get_register(instr.lhs); + let rhs = UntypedValue::from(T::from(instr.rhs)); + if bool::from(f(lhs, rhs)) { + self.branch_to(instr.offset.into()); + } else { + self.next_instr() + } + } +} + +macro_rules! impl_execute_branch_binop { + ( $( (Instruction::$op_name:ident, $fn_name:ident, $op:expr) ),* $(,)? ) => { + impl<'ctx, 'engine> Executor<'ctx, 'engine> { + $( + #[doc = concat!("Executes an [`Instruction::", stringify!($op_name), "`].")] + #[inline(always)] + pub fn $fn_name(&mut self, instr: BranchBinOpInstr) { + self.execute_branch_binop(instr, $op) + } + )* + } + } +} +impl_execute_branch_binop! { + (Instruction::BranchI32Eq, execute_branch_i32_eq, UntypedValue::i32_eq), + (Instruction::BranchI32Ne, execute_branch_i32_ne, UntypedValue::i32_ne), + (Instruction::BranchI32LtS, execute_branch_i32_lt_s, UntypedValue::i32_lt_s), + (Instruction::BranchI32LtU, execute_branch_i32_lt_u, UntypedValue::i32_lt_u), + (Instruction::BranchI32LeS, execute_branch_i32_le_s, UntypedValue::i32_le_s), + (Instruction::BranchI32LeU, execute_branch_i32_le_u, UntypedValue::i32_le_u), + (Instruction::BranchI32GtS, execute_branch_i32_gt_s, UntypedValue::i32_gt_s), + (Instruction::BranchI32GtU, execute_branch_i32_gt_u, UntypedValue::i32_gt_u), + (Instruction::BranchI32GeS, execute_branch_i32_ge_s, UntypedValue::i32_ge_s), + (Instruction::BranchI32GeU, execute_branch_i32_ge_u, UntypedValue::i32_ge_u), + + (Instruction::BranchI64Eq, execute_branch_i64_eq, UntypedValue::i64_eq), + (Instruction::BranchI64Ne, execute_branch_i64_ne, UntypedValue::i64_ne), + (Instruction::BranchI64LtS, execute_branch_i64_lt_s, UntypedValue::i64_lt_s), + (Instruction::BranchI64LtU, execute_branch_i64_lt_u, UntypedValue::i64_lt_u), + (Instruction::BranchI64LeS, execute_branch_i64_le_s, UntypedValue::i64_le_s), + (Instruction::BranchI64LeU, execute_branch_i64_le_u, UntypedValue::i64_le_u), + (Instruction::BranchI64GtS, execute_branch_i64_gt_s, UntypedValue::i64_gt_s), + (Instruction::BranchI64GtU, execute_branch_i64_gt_u, UntypedValue::i64_gt_u), + (Instruction::BranchI64GeS, execute_branch_i64_ge_s, UntypedValue::i64_ge_s), + (Instruction::BranchI64GeU, execute_branch_i64_ge_u, UntypedValue::i64_ge_u), + + (Instruction::BranchF32Eq, execute_branch_f32_eq, UntypedValue::f32_eq), + (Instruction::BranchF32Ne, execute_branch_f32_ne, UntypedValue::f32_ne), + (Instruction::BranchF32Lt, execute_branch_f32_lt, UntypedValue::f32_lt), + (Instruction::BranchF32Le, execute_branch_f32_le, UntypedValue::f32_le), + (Instruction::BranchF32Gt, execute_branch_f32_gt, UntypedValue::f32_gt), + (Instruction::BranchF32Ge, execute_branch_f32_ge, UntypedValue::f32_ge), + + (Instruction::BranchF64Eq, execute_branch_f64_eq, UntypedValue::f64_eq), + (Instruction::BranchF64Ne, execute_branch_f64_ne, UntypedValue::f64_ne), + (Instruction::BranchF64Lt, execute_branch_f64_lt, UntypedValue::f64_lt), + (Instruction::BranchF64Le, execute_branch_f64_le, UntypedValue::f64_le), + (Instruction::BranchF64Gt, execute_branch_f64_gt, UntypedValue::f64_gt), + (Instruction::BranchF64Ge, execute_branch_f64_ge, UntypedValue::f64_ge), +} + +macro_rules! impl_execute_branch_binop_imm { + ( $( (Instruction::$op_name:ident, $fn_name:ident, $op:expr, $ty:ty) ),* $(,)? ) => { + impl<'ctx, 'engine> Executor<'ctx, 'engine> { + $( + #[doc = concat!("Executes an [`Instruction::", stringify!($op_name), "`].")] + #[inline(always)] + pub fn $fn_name(&mut self, instr: BranchBinOpInstrImm<$ty>) { + self.execute_branch_binop_imm(instr, $op) + } + )* + } + } +} +impl_execute_branch_binop_imm! { + (Instruction::BranchI32EqImm, execute_branch_i32_eq_imm, UntypedValue::i32_eq, i32), + (Instruction::BranchI32NeImm, execute_branch_i32_ne_imm, UntypedValue::i32_ne, i32), + (Instruction::BranchI32LtSImm, execute_branch_i32_lt_s_imm, UntypedValue::i32_lt_s, i32), + (Instruction::BranchI32LtUImm, execute_branch_i32_lt_u_imm, UntypedValue::i32_lt_u, u32), + (Instruction::BranchI32LeSImm, execute_branch_i32_le_s_imm, UntypedValue::i32_le_s, i32), + (Instruction::BranchI32LeUImm, execute_branch_i32_le_u_imm, UntypedValue::i32_le_u, u32), + (Instruction::BranchI32GtSImm, execute_branch_i32_gt_s_imm, UntypedValue::i32_gt_s, i32), + (Instruction::BranchI32GtUImm, execute_branch_i32_gt_u_imm, UntypedValue::i32_gt_u, u32), + (Instruction::BranchI32GeSImm, execute_branch_i32_ge_s_imm, UntypedValue::i32_ge_s, i32), + (Instruction::BranchI32GeUImm, execute_branch_i32_ge_u_imm, UntypedValue::i32_ge_u, u32), + + (Instruction::BranchI64EqImm, execute_branch_i64_eq_imm, UntypedValue::i64_eq, i64), + (Instruction::BranchI64NeImm, execute_branch_i64_ne_imm, UntypedValue::i64_ne, i64), + (Instruction::BranchI64LtSImm, execute_branch_i64_lt_s_imm, UntypedValue::i64_lt_s, i64), + (Instruction::BranchI64LtUImm, execute_branch_i64_lt_u_imm, UntypedValue::i64_lt_u, u64), + (Instruction::BranchI64LeSImm, execute_branch_i64_le_s_imm, UntypedValue::i64_le_s, i64), + (Instruction::BranchI64LeUImm, execute_branch_i64_le_u_imm, UntypedValue::i64_le_u, u64), + (Instruction::BranchI64GtSImm, execute_branch_i64_gt_s_imm, UntypedValue::i64_gt_s, i64), + (Instruction::BranchI64GtUImm, execute_branch_i64_gt_u_imm, UntypedValue::i64_gt_u, u64), + (Instruction::BranchI64GeSImm, execute_branch_i64_ge_s_imm, UntypedValue::i64_ge_s, i64), + (Instruction::BranchI64GeUImm, execute_branch_i64_ge_u_imm, UntypedValue::i64_ge_u, u64), +} diff --git a/crates/wasmi/src/engine/regmach/tests/op/cmp_br.rs b/crates/wasmi/src/engine/regmach/tests/op/cmp_br.rs new file mode 100644 index 0000000000..f408f197ca --- /dev/null +++ b/crates/wasmi/src/engine/regmach/tests/op/cmp_br.rs @@ -0,0 +1,438 @@ +use super::{wasm_type::WasmType, *}; +use crate::{ + core::ValueType, + engine::{ + bytecode::{BranchOffset, GlobalIdx}, + regmach::{bytecode::BranchOffset16, tests::display_wasm::DisplayValueType}, + }, +}; +use std::fmt::{Debug, Display}; + +#[test] +#[cfg_attr(miri, ignore)] +fn loop_backward() { + fn test_for( + ty: ValueType, + op: &str, + expect_instr: fn(Register, Register, BranchOffset16) -> Instruction, + ) { + let ty = DisplayValueType::from(ty); + let wasm = wat2wasm(&format!( + r" + (module + (func (param {ty} {ty}) + (loop + (local.get 0) + (local.get 1) + ({ty}.{op}) + (br_if 0) + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + expect_instr( + Register::from_i16(0), + Register::from_i16(1), + BranchOffset16::from(0), + ), + Instruction::Return, + ]) + .run() + } + + test_for(ValueType::I32, "eq", Instruction::branch_i32_eq); + test_for(ValueType::I32, "ne", Instruction::branch_i32_ne); + test_for(ValueType::I32, "lt_s", Instruction::branch_i32_lt_s); + test_for(ValueType::I32, "lt_u", Instruction::branch_i32_lt_u); + test_for(ValueType::I32, "le_s", Instruction::branch_i32_le_s); + test_for(ValueType::I32, "le_u", Instruction::branch_i32_le_u); + test_for(ValueType::I32, "gt_s", Instruction::branch_i32_gt_s); + test_for(ValueType::I32, "gt_u", Instruction::branch_i32_gt_u); + test_for(ValueType::I32, "ge_s", Instruction::branch_i32_ge_s); + test_for(ValueType::I32, "ge_u", Instruction::branch_i32_ge_u); + + test_for(ValueType::I64, "eq", Instruction::branch_i64_eq); + test_for(ValueType::I64, "ne", Instruction::branch_i64_ne); + test_for(ValueType::I64, "lt_s", Instruction::branch_i64_lt_s); + test_for(ValueType::I64, "lt_u", Instruction::branch_i64_lt_u); + test_for(ValueType::I64, "le_s", Instruction::branch_i64_le_s); + test_for(ValueType::I64, "le_u", Instruction::branch_i64_le_u); + test_for(ValueType::I64, "gt_s", Instruction::branch_i64_gt_s); + test_for(ValueType::I64, "gt_u", Instruction::branch_i64_gt_u); + test_for(ValueType::I64, "ge_s", Instruction::branch_i64_ge_s); + test_for(ValueType::I64, "ge_u", Instruction::branch_i64_ge_u); + + test_for(ValueType::F32, "eq", Instruction::branch_f32_eq); + test_for(ValueType::F32, "ne", Instruction::branch_f32_ne); + test_for(ValueType::F32, "lt", Instruction::branch_f32_lt); + test_for(ValueType::F32, "le", Instruction::branch_f32_le); + test_for(ValueType::F32, "gt", Instruction::branch_f32_gt); + test_for(ValueType::F32, "ge", Instruction::branch_f32_ge); + + test_for(ValueType::F64, "eq", Instruction::branch_f64_eq); + test_for(ValueType::F64, "ne", Instruction::branch_f64_ne); + test_for(ValueType::F64, "lt", Instruction::branch_f64_lt); + test_for(ValueType::F64, "le", Instruction::branch_f64_le); + test_for(ValueType::F64, "gt", Instruction::branch_f64_gt); + test_for(ValueType::F64, "ge", Instruction::branch_f64_ge); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn loop_backward_imm() { + fn test_for( + op: &str, + value: T, + expect_instr: fn(Register, Const16, BranchOffset16) -> Instruction, + ) where + T: WasmType, + Const16: TryFrom + Debug, + DisplayWasm: Display, + { + let ty = T::NAME; + let display_value = DisplayWasm::from(value); + let wasm = wat2wasm(&format!( + r" + (module + (func (param {ty} {ty}) + (loop + (local.get 0) + ({ty}.const {display_value}) + ({ty}.{op}) + (br_if 0) + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + expect_instr( + Register::from_i16(0), + >::try_from(value).ok().unwrap(), + BranchOffset16::from(0), + ), + Instruction::Return, + ]) + .run() + } + test_for::("eq", 1, Instruction::branch_i32_eq_imm); + test_for::("ne", 1, Instruction::branch_i32_ne_imm); + test_for::("lt_s", 1, Instruction::branch_i32_lt_s_imm); + test_for::("lt_u", 1, Instruction::branch_i32_lt_u_imm); + test_for::("le_s", 1, Instruction::branch_i32_le_s_imm); + test_for::("le_u", 1, Instruction::branch_i32_le_u_imm); + test_for::("gt_s", 1, Instruction::branch_i32_gt_s_imm); + test_for::("gt_u", 1, Instruction::branch_i32_gt_u_imm); + test_for::("ge_s", 1, Instruction::branch_i32_ge_s_imm); + test_for::("ge_u", 1, Instruction::branch_i32_ge_u_imm); + + test_for::("eq", 1, Instruction::branch_i64_eq_imm); + test_for::("ne", 1, Instruction::branch_i64_ne_imm); + test_for::("lt_s", 1, Instruction::branch_i64_lt_s_imm); + test_for::("lt_u", 1, Instruction::branch_i64_lt_u_imm); + test_for::("le_s", 1, Instruction::branch_i64_le_s_imm); + test_for::("le_u", 1, Instruction::branch_i64_le_u_imm); + test_for::("gt_s", 1, Instruction::branch_i64_gt_s_imm); + test_for::("gt_u", 1, Instruction::branch_i64_gt_u_imm); + test_for::("ge_s", 1, Instruction::branch_i64_ge_s_imm); + test_for::("ge_u", 1, Instruction::branch_i64_ge_u_imm); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn loop_backward_imm_eqz() { + fn test_for(op: &str, expect_instr: fn(Register, BranchOffset) -> Instruction) { + let wasm = wat2wasm(&format!( + r" + (module + (func (param i32 i32) + (loop + (local.get 0) + (i32.const 0) + (i32.{op}) + (br_if 0) + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + expect_instr(Register::from_i16(0), BranchOffset::from(0_i32)), + Instruction::Return, + ]) + .run() + } + test_for("eq", Instruction::branch_eqz); + test_for("ne", Instruction::branch_nez); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn block_forward() { + fn test_for( + ty: ValueType, + op: &str, + expect_instr: fn(Register, Register, BranchOffset16) -> Instruction, + ) { + let ty = DisplayValueType::from(ty); + let wasm = wat2wasm(&format!( + r" + (module + (func (param {ty} {ty}) + (block + (local.get 0) + (local.get 1) + ({ty}.{op}) + (br_if 0) + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + expect_instr( + Register::from_i16(0), + Register::from_i16(1), + BranchOffset16::from(1), + ), + Instruction::Return, + ]) + .run() + } + + test_for(ValueType::I32, "eq", Instruction::branch_i32_eq); + test_for(ValueType::I32, "ne", Instruction::branch_i32_ne); + test_for(ValueType::I32, "lt_s", Instruction::branch_i32_lt_s); + test_for(ValueType::I32, "lt_u", Instruction::branch_i32_lt_u); + test_for(ValueType::I32, "le_s", Instruction::branch_i32_le_s); + test_for(ValueType::I32, "le_u", Instruction::branch_i32_le_u); + test_for(ValueType::I32, "gt_s", Instruction::branch_i32_gt_s); + test_for(ValueType::I32, "gt_u", Instruction::branch_i32_gt_u); + test_for(ValueType::I32, "ge_s", Instruction::branch_i32_ge_s); + test_for(ValueType::I32, "ge_u", Instruction::branch_i32_ge_u); + + test_for(ValueType::I64, "eq", Instruction::branch_i64_eq); + test_for(ValueType::I64, "ne", Instruction::branch_i64_ne); + test_for(ValueType::I64, "lt_s", Instruction::branch_i64_lt_s); + test_for(ValueType::I64, "lt_u", Instruction::branch_i64_lt_u); + test_for(ValueType::I64, "le_s", Instruction::branch_i64_le_s); + test_for(ValueType::I64, "le_u", Instruction::branch_i64_le_u); + test_for(ValueType::I64, "gt_s", Instruction::branch_i64_gt_s); + test_for(ValueType::I64, "gt_u", Instruction::branch_i64_gt_u); + test_for(ValueType::I64, "ge_s", Instruction::branch_i64_ge_s); + test_for(ValueType::I64, "ge_u", Instruction::branch_i64_ge_u); + + test_for(ValueType::F32, "eq", Instruction::branch_f32_eq); + test_for(ValueType::F32, "ne", Instruction::branch_f32_ne); + test_for(ValueType::F32, "lt", Instruction::branch_f32_lt); + test_for(ValueType::F32, "le", Instruction::branch_f32_le); + test_for(ValueType::F32, "gt", Instruction::branch_f32_gt); + test_for(ValueType::F32, "ge", Instruction::branch_f32_ge); + + test_for(ValueType::F64, "eq", Instruction::branch_f64_eq); + test_for(ValueType::F64, "ne", Instruction::branch_f64_ne); + test_for(ValueType::F64, "lt", Instruction::branch_f64_lt); + test_for(ValueType::F64, "le", Instruction::branch_f64_le); + test_for(ValueType::F64, "gt", Instruction::branch_f64_gt); + test_for(ValueType::F64, "ge", Instruction::branch_f64_ge); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn block_forward_nop_copy() { + fn test_for( + ty: ValueType, + op: &str, + expect_instr: fn(Register, Register, BranchOffset16) -> Instruction, + ) { + let ty = DisplayValueType::from(ty); + let wasm = wat2wasm(&format!( + r" + (module + (global $g (mut {ty}) ({ty}.const 10)) + (func (param {ty} {ty}) (result {ty}) + (global.get $g) + (block (param {ty}) (result {ty}) + (local.get 0) + (local.get 1) + ({ty}.{op}) + (br_if 0) + (drop) + (local.get 0) + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + Instruction::global_get(Register::from_i16(2), GlobalIdx::from(0)), + expect_instr( + Register::from_i16(0), + Register::from_i16(1), + BranchOffset16::from(2), + ), + Instruction::copy(Register::from_i16(2), Register::from_i16(0)), + Instruction::return_reg(2), + ]) + .run() + } + + test_for(ValueType::I32, "eq", Instruction::branch_i32_eq); + test_for(ValueType::I32, "ne", Instruction::branch_i32_ne); + test_for(ValueType::I32, "lt_s", Instruction::branch_i32_lt_s); + test_for(ValueType::I32, "lt_u", Instruction::branch_i32_lt_u); + test_for(ValueType::I32, "le_s", Instruction::branch_i32_le_s); + test_for(ValueType::I32, "le_u", Instruction::branch_i32_le_u); + test_for(ValueType::I32, "gt_s", Instruction::branch_i32_gt_s); + test_for(ValueType::I32, "gt_u", Instruction::branch_i32_gt_u); + test_for(ValueType::I32, "ge_s", Instruction::branch_i32_ge_s); + test_for(ValueType::I32, "ge_u", Instruction::branch_i32_ge_u); + + test_for(ValueType::I64, "eq", Instruction::branch_i64_eq); + test_for(ValueType::I64, "ne", Instruction::branch_i64_ne); + test_for(ValueType::I64, "lt_s", Instruction::branch_i64_lt_s); + test_for(ValueType::I64, "lt_u", Instruction::branch_i64_lt_u); + test_for(ValueType::I64, "le_s", Instruction::branch_i64_le_s); + test_for(ValueType::I64, "le_u", Instruction::branch_i64_le_u); + test_for(ValueType::I64, "gt_s", Instruction::branch_i64_gt_s); + test_for(ValueType::I64, "gt_u", Instruction::branch_i64_gt_u); + test_for(ValueType::I64, "ge_s", Instruction::branch_i64_ge_s); + test_for(ValueType::I64, "ge_u", Instruction::branch_i64_ge_u); + + test_for(ValueType::F32, "eq", Instruction::branch_f32_eq); + test_for(ValueType::F32, "ne", Instruction::branch_f32_ne); + test_for(ValueType::F32, "lt", Instruction::branch_f32_lt); + test_for(ValueType::F32, "le", Instruction::branch_f32_le); + test_for(ValueType::F32, "gt", Instruction::branch_f32_gt); + test_for(ValueType::F32, "ge", Instruction::branch_f32_ge); + + test_for(ValueType::F64, "eq", Instruction::branch_f64_eq); + test_for(ValueType::F64, "ne", Instruction::branch_f64_ne); + test_for(ValueType::F64, "lt", Instruction::branch_f64_lt); + test_for(ValueType::F64, "le", Instruction::branch_f64_le); + test_for(ValueType::F64, "gt", Instruction::branch_f64_gt); + test_for(ValueType::F64, "ge", Instruction::branch_f64_ge); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn if_forward_multi_value() { + fn test_for( + ty: ValueType, + op: &str, + expect_instr: fn(Register, Register, BranchOffset16) -> Instruction, + ) { + let ty = DisplayValueType::from(ty); + let wasm = wat2wasm(&format!( + r" + (module + (func (param {ty} {ty}) (result {ty}) + (block (result {ty}) + (local.get 0) ;; returned from block if `local.get 0 != 0` + (local.get 0) + (local.get 1) + ({ty}.{op}) + (br_if 0) + (drop) + (local.get 1) ;; returned from block if `local.get 0 == 0` + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + expect_instr( + Register::from_i16(0), + Register::from_i16(1), + BranchOffset16::from(3), + ), + Instruction::copy(Register::from_i16(2), Register::from_i16(0)), + Instruction::branch(BranchOffset::from(2)), + Instruction::copy(Register::from_i16(2), Register::from_i16(1)), + Instruction::return_reg(2), + ]) + .run() + } + + test_for(ValueType::I32, "eq", Instruction::branch_i32_ne); + test_for(ValueType::I32, "ne", Instruction::branch_i32_eq); + test_for(ValueType::I32, "lt_s", Instruction::branch_i32_ge_s); + test_for(ValueType::I32, "lt_u", Instruction::branch_i32_ge_u); + test_for(ValueType::I32, "le_s", Instruction::branch_i32_gt_s); + test_for(ValueType::I32, "le_u", Instruction::branch_i32_gt_u); + test_for(ValueType::I32, "gt_s", Instruction::branch_i32_le_s); + test_for(ValueType::I32, "gt_u", Instruction::branch_i32_le_u); + test_for(ValueType::I32, "ge_s", Instruction::branch_i32_lt_s); + test_for(ValueType::I32, "ge_u", Instruction::branch_i32_lt_u); + + test_for(ValueType::I64, "eq", Instruction::branch_i64_ne); + test_for(ValueType::I64, "ne", Instruction::branch_i64_eq); + test_for(ValueType::I64, "lt_s", Instruction::branch_i64_ge_s); + test_for(ValueType::I64, "lt_u", Instruction::branch_i64_ge_u); + test_for(ValueType::I64, "le_s", Instruction::branch_i64_gt_s); + test_for(ValueType::I64, "le_u", Instruction::branch_i64_gt_u); + test_for(ValueType::I64, "gt_s", Instruction::branch_i64_le_s); + test_for(ValueType::I64, "gt_u", Instruction::branch_i64_le_u); + test_for(ValueType::I64, "ge_s", Instruction::branch_i64_lt_s); + test_for(ValueType::I64, "ge_u", Instruction::branch_i64_lt_u); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn if_forward() { + fn test_for( + ty: ValueType, + op: &str, + expect_instr: fn(Register, Register, BranchOffset16) -> Instruction, + ) { + let ty = DisplayValueType::from(ty); + let wasm = wat2wasm(&format!( + r" + (module + (func (param {ty} {ty}) + (if + ({ty}.{op} + (local.get 0) + (local.get 1) + ) + (then) + ) + ) + )", + )); + TranslationTest::new(wasm) + .expect_func_instrs([ + expect_instr( + Register::from_i16(0), + Register::from_i16(1), + BranchOffset16::from(1), + ), + Instruction::Return, + ]) + .run() + } + + test_for(ValueType::I32, "eq", Instruction::branch_i32_ne); + test_for(ValueType::I32, "ne", Instruction::branch_i32_eq); + test_for(ValueType::I32, "lt_s", Instruction::branch_i32_ge_s); + test_for(ValueType::I32, "lt_u", Instruction::branch_i32_ge_u); + test_for(ValueType::I32, "le_s", Instruction::branch_i32_gt_s); + test_for(ValueType::I32, "le_u", Instruction::branch_i32_gt_u); + test_for(ValueType::I32, "gt_s", Instruction::branch_i32_le_s); + test_for(ValueType::I32, "gt_u", Instruction::branch_i32_le_u); + test_for(ValueType::I32, "ge_s", Instruction::branch_i32_lt_s); + test_for(ValueType::I32, "ge_u", Instruction::branch_i32_lt_u); + + test_for(ValueType::I64, "eq", Instruction::branch_i64_ne); + test_for(ValueType::I64, "ne", Instruction::branch_i64_eq); + test_for(ValueType::I64, "lt_s", Instruction::branch_i64_ge_s); + test_for(ValueType::I64, "lt_u", Instruction::branch_i64_ge_u); + test_for(ValueType::I64, "le_s", Instruction::branch_i64_gt_s); + test_for(ValueType::I64, "le_u", Instruction::branch_i64_gt_u); + test_for(ValueType::I64, "gt_s", Instruction::branch_i64_le_s); + test_for(ValueType::I64, "gt_u", Instruction::branch_i64_le_u); + test_for(ValueType::I64, "ge_s", Instruction::branch_i64_lt_s); + test_for(ValueType::I64, "ge_u", Instruction::branch_i64_lt_u); +} diff --git a/crates/wasmi/src/engine/regmach/tests/op/mod.rs b/crates/wasmi/src/engine/regmach/tests/op/mod.rs index b2b18e0a06..0eab96f2d2 100644 --- a/crates/wasmi/src/engine/regmach/tests/op/mod.rs +++ b/crates/wasmi/src/engine/regmach/tests/op/mod.rs @@ -5,6 +5,7 @@ mod br_if; mod br_table; mod call; mod cmp; +mod cmp_br; mod global_get; mod global_set; mod if_; @@ -48,12 +49,24 @@ use super::{ WasmType, }; -/// Creates an [`Const32`] from the given `i64` value. +/// Creates an [`Const32`] from the given `i32` value. /// /// # Panics /// /// If the `value` cannot be converted into `i32` losslessly. #[track_caller] +#[allow(dead_code)] +fn i32imm16(value: i32) -> Const16 { + >::from_i32(value) + .unwrap_or_else(|| panic!("value must be 16-bit encodable: {}", value)) +} + +/// Creates an [`Const32`] from the given `u32` value. +/// +/// # Panics +/// +/// If the `value` cannot be converted into `u32` losslessly. +#[track_caller] fn u32imm16(value: u32) -> Const16 { >::from_u32(value) .unwrap_or_else(|| panic!("value must be 16-bit encodable: {}", value)) diff --git a/crates/wasmi/src/engine/regmach/tests/wasm_type.rs b/crates/wasmi/src/engine/regmach/tests/wasm_type.rs index df293e11a4..ee07a4cc25 100644 --- a/crates/wasmi/src/engine/regmach/tests/wasm_type.rs +++ b/crates/wasmi/src/engine/regmach/tests/wasm_type.rs @@ -13,6 +13,15 @@ pub trait WasmType: Copy + Display + Into + From { fn return_imm_instr(&self) -> Instruction; } +impl WasmType for u32 { + const NAME: &'static str = "i32"; + const VALUE_TYPE: ValueType = ValueType::I32; + + fn return_imm_instr(&self) -> Instruction { + Instruction::return_imm32(*self) + } +} + impl WasmType for i32 { const NAME: &'static str = "i32"; const VALUE_TYPE: ValueType = ValueType::I32; @@ -22,6 +31,18 @@ impl WasmType for i32 { } } +impl WasmType for u64 { + const NAME: &'static str = "i64"; + const VALUE_TYPE: ValueType = ValueType::I64; + + fn return_imm_instr(&self) -> Instruction { + match >::from_i64(*self as i64) { + Some(value) => Instruction::return_i64imm32(value), + None => Instruction::return_reg(Register::from_i16(-1)), + } + } +} + impl WasmType for i64 { const NAME: &'static str = "i64"; const VALUE_TYPE: ValueType = ValueType::I64; diff --git a/crates/wasmi/src/engine/regmach/translator/instr_encoder.rs b/crates/wasmi/src/engine/regmach/translator/instr_encoder.rs index 174f0d16fd..0a522747a9 100644 --- a/crates/wasmi/src/engine/regmach/translator/instr_encoder.rs +++ b/crates/wasmi/src/engine/regmach/translator/instr_encoder.rs @@ -7,14 +7,26 @@ use crate::{ Instr, }, regmach::{ - bytecode::{Const32, Instruction, Provider, Register, RegisterSpan, RegisterSpanIter}, - translator::ValueStack, + bytecode::{ + BinInstr, + BinInstrImm16, + BranchOffset16, + Const16, + Const32, + Instruction, + Provider, + Register, + RegisterSpan, + RegisterSpanIter, + }, + translator::{stack::RegisterSpace, ValueStack}, }, TranslationError, }, module::ModuleResources, }; use alloc::vec::{Drain, Vec}; +use core::mem; use wasmi_core::{UntypedValue, ValueType, F32}; /// Encodes `wasmi` bytecode instructions to an [`Instruction`] stream. @@ -207,7 +219,7 @@ impl InstrEncoder { /// If this is used before all branching labels have been pinned. pub fn update_branch_offsets(&mut self) -> Result<(), TranslationError> { for (user, offset) in self.labels.resolved_users() { - self.instrs.get_mut(user).update_branch_offset(offset?); + self.instrs.get_mut(user).update_branch_offset(offset?)?; } Ok(()) } @@ -615,27 +627,38 @@ impl InstrEncoder { /// to the same basic block. pub fn encode_local_set( &mut self, + stack: &mut ValueStack, res: &ModuleResources, local: Register, value: Register, ) -> Result<(), TranslationError> { - if let Some(last_instr) = self.last_instr { - if let Some(result) = self.instrs.get_mut(last_instr).result_mut(res) { - // Case: we can replace the `result` register of the previous - // instruction instead of emitting a copy instruction. - if *result == value { - // TODO: Find out in what cases `result != value`. Is this a bug or an edge case? - // Generally `result` should be equal to `value` since `value` refers to the - // `result` of the previous instruction. - // Therefore, instead of an `if` we originally had a `debug_assert`. - // (Note: the spidermonkey bench test failed without this change.) - *result = local; - return Ok(()); - } - } + /// Fallback for when we need to encode a `copy` instruction to encode the `local.set` or `local.tee`. + fn fallback_copy( + this: &mut InstrEncoder, + local: Register, + value: Register, + ) -> Result<(), TranslationError> { + this.push_instr(Instruction::copy(local, value))?; + Ok(()) + } + let Some(last_instr) = self.last_instr else { + return fallback_copy(self, local, value); + }; + let Some(result) = self.instrs.get_mut(last_instr).result_mut(res) else { + return fallback_copy(self, local, value); + }; + if matches!(stack.get_register_space(*result), RegisterSpace::Local) { + return fallback_copy(self, local, value); } - // Case: we need to encode a copy instruction to encode the `local.set` or `local.tee`. - self.push_instr(Instruction::copy(local, value))?; + if *result != value { + // TODO: Find out in what cases `result != value`. Is this a bug or an edge case? + // Generally `result` should be equal to `value` since `value` refers to the + // `result` of the previous instruction. + // Therefore, instead of an `if` we originally had a `debug_assert`. + // (Note: the spidermonkey bench test failed without this change.) + return fallback_copy(self, local, value); + } + *result = local; Ok(()) } @@ -676,6 +699,315 @@ impl InstrEncoder { } Ok(()) } + + /// Encodes a `branch_eqz` instruction and tries to fuse it with a previous comparison instruction. + pub fn encode_branch_eqz( + &mut self, + stack: &mut ValueStack, + condition: Register, + label: LabelRef, + ) -> Result<(), TranslationError> { + type BranchCmpConstructor = fn(Register, Register, BranchOffset16) -> Instruction; + type BranchCmpImmConstructor = fn(Register, Const16, BranchOffset16) -> Instruction; + + /// Encode an unoptimized `branch_eqz` instruction. + /// + /// This is used as fallback whenever fusing compare and branch instructions is not possible. + fn encode_branch_eqz_fallback( + this: &mut InstrEncoder, + condition: Register, + label: LabelRef, + ) -> Result<(), TranslationError> { + let offset = this.try_resolve_label(label)?; + this.push_instr(Instruction::branch_eqz(condition, offset))?; + Ok(()) + } + + /// Create a fused cmp+branch instruction and wrap it in a `Some`. + /// + /// We wrap the returned value in `Some` to unify handling of a bunch of cases. + fn fuse( + this: &mut InstrEncoder, + stack: &mut ValueStack, + last_instr: Instr, + instr: BinInstr, + label: LabelRef, + make_instr: BranchCmpConstructor, + ) -> Result, TranslationError> { + if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { + // We need to filter out instructions that store their result + // into a local register slot because they introduce observable behavior + // which a fused cmp+branch instruction would remove. + return Ok(None); + } + let offset = this.try_resolve_label_for(label, last_instr)?; + let instr = BranchOffset16::new(offset) + .map(|offset16| make_instr(instr.lhs, instr.rhs, offset16)); + Ok(instr) + } + + /// Create a fused cmp+branch instruction with a 16-bit immediate and wrap it in a `Some`. + /// + /// We wrap the returned value in `Some` to unify handling of a bunch of cases. + fn fuse_imm( + this: &mut InstrEncoder, + stack: &mut ValueStack, + last_instr: Instr, + instr: BinInstrImm16, + label: LabelRef, + make_instr: BranchCmpImmConstructor, + ) -> Result, TranslationError> { + if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { + // We need to filter out instructions that store their result + // into a local register slot because they introduce observable behavior + // which a fused cmp+branch instruction would remove. + return Ok(None); + } + let offset = this.try_resolve_label_for(label, last_instr)?; + let instr = BranchOffset16::new(offset) + .map(|offset16| make_instr(instr.reg_in, instr.imm_in, offset16)); + Ok(instr) + } + use Instruction as I; + + let Some(last_instr) = self.last_instr else { + return encode_branch_eqz_fallback(self, condition, label); + }; + + #[rustfmt::skip] + let fused_instr = match *self.instrs.get(last_instr) { + I::I32EqImm16(instr) if instr.imm_in.is_zero() => { + match stack.get_register_space(instr.result) { + RegisterSpace::Local => None, + _ => { + // Note: unfortunately we cannot apply this optimization for `i64` variants + // since the standard `branch_eqz` assumes its operands to be of type `i32`. + let offset32 = self.try_resolve_label_for(label, last_instr)?; + Some(Instruction::branch_nez(instr.reg_in, offset32)) + } + } + } + I::I32NeImm16(instr) if instr.imm_in.is_zero() => { + match stack.get_register_space(instr.result) { + RegisterSpace::Local => None, + _ => { + // Note: unfortunately we cannot apply this optimization for `i64` variants + // since the standard `branch_nez` assumes its operands to be of type `i32`. + let offset32 = self.try_resolve_label_for(label, last_instr)?; + Some(Instruction::branch_eqz(instr.reg_in, offset32)) + } + } + } + I::I32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ne as _)?, + I::I32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_eq as _)?, + I::I32LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_s as _)?, + I::I32LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_u as _)?, + I::I32LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_s as _)?, + I::I32LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_u as _)?, + I::I32GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_s as _)?, + I::I32GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_u as _)?, + I::I32GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_s as _)?, + I::I32GeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_u as _)?, + I::I64Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ne as _)?, + I::I64Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_eq as _)?, + I::I64LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_s as _)?, + I::I64LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_u as _)?, + I::I64LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_s as _)?, + I::I64LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_u as _)?, + I::I64GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_s as _)?, + I::I64GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_u as _)?, + I::I64GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_s as _)?, + I::I64GeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_u as _)?, + I::F32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_ne as _)?, + I::F32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_eq as _)?, + // Note: We cannot fuse cmp+branch for float comparison operators due to how NaN values are treated. + I::I32EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ne_imm as _)?, + I::I32NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_eq_imm as _)?, + I::I32LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_s_imm as _)?, + I::I32LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_u_imm as _)?, + I::I32LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_s_imm as _)?, + I::I32LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_u_imm as _)?, + I::I32GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_s_imm as _)?, + I::I32GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_u_imm as _)?, + I::I32GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_s_imm as _)?, + I::I32GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_u_imm as _)?, + I::I64EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ne_imm as _)?, + I::I64NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_eq_imm as _)?, + I::I64LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_s_imm as _)?, + I::I64LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_u_imm as _)?, + I::I64LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_s_imm as _)?, + I::I64LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_u_imm as _)?, + I::I64GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_s_imm as _)?, + I::I64GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_u_imm as _)?, + I::I64GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_s_imm as _)?, + I::I64GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_u_imm as _)?, + _ => None, + }; + if let Some(fused_instr) = fused_instr { + _ = mem::replace(self.instrs.get_mut(last_instr), fused_instr); + return Ok(()); + } + encode_branch_eqz_fallback(self, condition, label) + } + + /// Encodes a `branch_nez` instruction and tries to fuse it with a previous comparison instruction. + pub fn encode_branch_nez( + &mut self, + stack: &mut ValueStack, + condition: Register, + label: LabelRef, + ) -> Result<(), TranslationError> { + type BranchCmpConstructor = fn(Register, Register, BranchOffset16) -> Instruction; + type BranchCmpImmConstructor = fn(Register, Const16, BranchOffset16) -> Instruction; + + /// Encode an unoptimized `branch_nez` instruction. + /// + /// This is used as fallback whenever fusing compare and branch instructions is not possible. + fn encode_branch_nez_fallback( + this: &mut InstrEncoder, + condition: Register, + label: LabelRef, + ) -> Result<(), TranslationError> { + let offset = this.try_resolve_label(label)?; + this.push_instr(Instruction::branch_nez(condition, offset))?; + Ok(()) + } + + /// Create a fused cmp+branch instruction and wrap it in a `Some`. + /// + /// We wrap the returned value in `Some` to unify handling of a bunch of cases. + fn fuse( + this: &mut InstrEncoder, + stack: &mut ValueStack, + last_instr: Instr, + instr: BinInstr, + label: LabelRef, + make_instr: BranchCmpConstructor, + ) -> Result, TranslationError> { + if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { + // We need to filter out instructions that store their result + // into a local register slot because they introduce observable behavior + // which a fused cmp+branch instruction would remove. + return Ok(None); + } + let offset = this.try_resolve_label_for(label, last_instr)?; + let instr = BranchOffset16::new(offset) + .map(|offset16| make_instr(instr.lhs, instr.rhs, offset16)); + Ok(instr) + } + + /// Create a fused cmp+branch instruction with a 16-bit immediate and wrap it in a `Some`. + /// + /// We wrap the returned value in `Some` to unify handling of a bunch of cases. + fn fuse_imm( + this: &mut InstrEncoder, + stack: &mut ValueStack, + last_instr: Instr, + instr: BinInstrImm16, + label: LabelRef, + make_instr: BranchCmpImmConstructor, + ) -> Result, TranslationError> { + if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { + // We need to filter out instructions that store their result + // into a local register slot because they introduce observable behavior + // which a fused cmp+branch instruction would remove. + return Ok(None); + } + let offset = this.try_resolve_label_for(label, last_instr)?; + let instr = BranchOffset16::new(offset) + .map(|offset16| make_instr(instr.reg_in, instr.imm_in, offset16)); + Ok(instr) + } + use Instruction as I; + + let Some(last_instr) = self.last_instr else { + return encode_branch_nez_fallback(self, condition, label); + }; + + #[rustfmt::skip] + let fused_instr = match *self.instrs.get(last_instr) { + I::I32EqImm16(instr) if instr.imm_in.is_zero() => { + match stack.get_register_space(instr.result) { + RegisterSpace::Local => None, + _ => { + // Note: unfortunately we cannot apply this optimization for `i64` variants + // since the standard `branch_eqz` assumes its operands to be of type `i32`. + let offset32 = self.try_resolve_label_for(label, last_instr)?; + Some(Instruction::branch_eqz(instr.reg_in, offset32)) + } + } + } + I::I32NeImm16(instr) if instr.imm_in.is_zero() => { + match stack.get_register_space(instr.result) { + RegisterSpace::Local => None, + _ => { + // Note: unfortunately we cannot apply this optimization for `i64` variants + // since the standard `branch_nez` assumes its operands to be of type `i32`. + let offset32 = self.try_resolve_label_for(label, last_instr)?; + Some(Instruction::branch_nez(instr.reg_in, offset32)) + } + } + } + I::I32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_eq as _)?, + I::I32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ne as _)?, + I::I32LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_s as _)?, + I::I32LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_u as _)?, + I::I32LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_s as _)?, + I::I32LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_u as _)?, + I::I32GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_s as _)?, + I::I32GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_u as _)?, + I::I32GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_s as _)?, + I::I32GeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_u as _)?, + I::I64Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_eq as _)?, + I::I64Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ne as _)?, + I::I64LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_s as _)?, + I::I64LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_u as _)?, + I::I64LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_s as _)?, + I::I64LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_u as _)?, + I::I64GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_s as _)?, + I::I64GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_u as _)?, + I::I64GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_s as _)?, + I::I64GeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_u as _)?, + I::F32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_eq as _)?, + I::F32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_ne as _)?, + I::F32Lt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_lt as _)?, + I::F32Le(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_le as _)?, + I::F32Gt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_gt as _)?, + I::F32Ge(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_ge as _)?, + I::F64Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_eq as _)?, + I::F64Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_ne as _)?, + I::F64Lt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_lt as _)?, + I::F64Le(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_le as _)?, + I::F64Gt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_gt as _)?, + I::F64Ge(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_ge as _)?, + I::I32EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_eq_imm as _)?, + I::I32NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ne_imm as _)?, + I::I32LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_s_imm as _)?, + I::I32LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_u_imm as _)?, + I::I32LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_s_imm as _)?, + I::I32LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_u_imm as _)?, + I::I32GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_s_imm as _)?, + I::I32GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_u_imm as _)?, + I::I32GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_s_imm as _)?, + I::I32GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_u_imm as _)?, + I::I64EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_eq_imm as _)?, + I::I64NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ne_imm as _)?, + I::I64LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_s_imm as _)?, + I::I64LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_u_imm as _)?, + I::I64LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_s_imm as _)?, + I::I64LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_u_imm as _)?, + I::I64GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_s_imm as _)?, + I::I64GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_u_imm as _)?, + I::I64GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_s_imm as _)?, + I::I64GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_u_imm as _)?, + _ => None, + }; + if let Some(fused_instr) = fused_instr { + _ = mem::replace(self.instrs.get_mut(last_instr), fused_instr); + return Ok(()); + } + encode_branch_nez_fallback(self, condition, label) + } } impl Instruction { @@ -684,11 +1016,69 @@ impl Instruction { /// # Panics /// /// If `self` is not a branch [`Instruction`]. - pub fn update_branch_offset(&mut self, new_offset: BranchOffset) { + pub fn update_branch_offset( + &mut self, + new_offset: BranchOffset, + ) -> Result<(), TranslationError> { match self { Instruction::Branch { offset } | Instruction::BranchEqz { offset, .. } - | Instruction::BranchNez { offset, .. } => offset.init(new_offset), + | Instruction::BranchNez { offset, .. } => { + offset.init(new_offset); + Ok(()) + } + Instruction::BranchI32Eq(instr) + | Instruction::BranchI32Ne(instr) + | Instruction::BranchI32LtS(instr) + | Instruction::BranchI32LtU(instr) + | Instruction::BranchI32LeS(instr) + | Instruction::BranchI32LeU(instr) + | Instruction::BranchI32GtS(instr) + | Instruction::BranchI32GtU(instr) + | Instruction::BranchI32GeS(instr) + | Instruction::BranchI32GeU(instr) + | Instruction::BranchI64Eq(instr) + | Instruction::BranchI64Ne(instr) + | Instruction::BranchI64LtS(instr) + | Instruction::BranchI64LtU(instr) + | Instruction::BranchI64LeS(instr) + | Instruction::BranchI64LeU(instr) + | Instruction::BranchI64GtS(instr) + | Instruction::BranchI64GtU(instr) + | Instruction::BranchI64GeS(instr) + | Instruction::BranchI64GeU(instr) + | Instruction::BranchF32Eq(instr) + | Instruction::BranchF32Ne(instr) + | Instruction::BranchF32Lt(instr) + | Instruction::BranchF32Le(instr) + | Instruction::BranchF32Gt(instr) + | Instruction::BranchF32Ge(instr) + | Instruction::BranchF64Eq(instr) + | Instruction::BranchF64Ne(instr) + | Instruction::BranchF64Lt(instr) + | Instruction::BranchF64Le(instr) + | Instruction::BranchF64Gt(instr) + | Instruction::BranchF64Ge(instr) => instr.offset.init(new_offset), + Instruction::BranchI32EqImm(instr) + | Instruction::BranchI32NeImm(instr) + | Instruction::BranchI32LtSImm(instr) + | Instruction::BranchI32LeSImm(instr) + | Instruction::BranchI32GtSImm(instr) + | Instruction::BranchI32GeSImm(instr) => instr.offset.init(new_offset), + Instruction::BranchI32LtUImm(instr) + | Instruction::BranchI32LeUImm(instr) + | Instruction::BranchI32GtUImm(instr) + | Instruction::BranchI32GeUImm(instr) => instr.offset.init(new_offset), + Instruction::BranchI64EqImm(instr) + | Instruction::BranchI64NeImm(instr) + | Instruction::BranchI64LtSImm(instr) + | Instruction::BranchI64LeSImm(instr) + | Instruction::BranchI64GtSImm(instr) + | Instruction::BranchI64GeSImm(instr) => instr.offset.init(new_offset), + Instruction::BranchI64LtUImm(instr) + | Instruction::BranchI64LeUImm(instr) + | Instruction::BranchI64GtUImm(instr) + | Instruction::BranchI64GeUImm(instr) => instr.offset.init(new_offset), _ => panic!("tried to update branch offset of a non-branch instruction: {self:?}"), } } diff --git a/crates/wasmi/src/engine/regmach/translator/result_mut.rs b/crates/wasmi/src/engine/regmach/translator/result_mut.rs index 4cbcccbd64..f3d4f9c525 100644 --- a/crates/wasmi/src/engine/regmach/translator/result_mut.rs +++ b/crates/wasmi/src/engine/regmach/translator/result_mut.rs @@ -70,6 +70,58 @@ impl Instruction { Instruction::BranchEqz { .. } | Instruction::BranchNez { .. } | Instruction::BranchTable { .. } => None, + Instruction::BranchI32Eq(_) | + Instruction::BranchI32EqImm(_) | + Instruction::BranchI32Ne(_) | + Instruction::BranchI32NeImm(_) | + Instruction::BranchI32LtS(_) | + Instruction::BranchI32LtSImm(_) | + Instruction::BranchI32LtU(_) | + Instruction::BranchI32LtUImm(_) | + Instruction::BranchI32LeS(_) | + Instruction::BranchI32LeSImm(_) | + Instruction::BranchI32LeU(_) | + Instruction::BranchI32LeUImm(_) | + Instruction::BranchI32GtS(_) | + Instruction::BranchI32GtSImm(_) | + Instruction::BranchI32GtU(_) | + Instruction::BranchI32GtUImm(_) | + Instruction::BranchI32GeS(_) | + Instruction::BranchI32GeSImm(_) | + Instruction::BranchI32GeU(_) | + Instruction::BranchI32GeUImm(_) | + Instruction::BranchI64Eq(_) | + Instruction::BranchI64EqImm(_) | + Instruction::BranchI64Ne(_) | + Instruction::BranchI64NeImm(_) | + Instruction::BranchI64LtS(_) | + Instruction::BranchI64LtSImm(_) | + Instruction::BranchI64LtU(_) | + Instruction::BranchI64LtUImm(_) | + Instruction::BranchI64LeS(_) | + Instruction::BranchI64LeSImm(_) | + Instruction::BranchI64LeU(_) | + Instruction::BranchI64LeUImm(_) | + Instruction::BranchI64GtS(_) | + Instruction::BranchI64GtSImm(_) | + Instruction::BranchI64GtU(_) | + Instruction::BranchI64GtUImm(_) | + Instruction::BranchI64GeS(_) | + Instruction::BranchI64GeSImm(_) | + Instruction::BranchI64GeU(_) | + Instruction::BranchI64GeUImm(_) | + Instruction::BranchF32Eq(_) | + Instruction::BranchF32Ne(_) | + Instruction::BranchF32Lt(_) | + Instruction::BranchF32Le(_) | + Instruction::BranchF32Gt(_) | + Instruction::BranchF32Ge(_) | + Instruction::BranchF64Eq(_) | + Instruction::BranchF64Ne(_) | + Instruction::BranchF64Lt(_) | + Instruction::BranchF64Le(_) | + Instruction::BranchF64Gt(_) | + Instruction::BranchF64Ge(_) => None, Instruction::Copy { result, .. } | Instruction::CopyImm32 { result, .. } | Instruction::CopyI64Imm32 { result, .. } | diff --git a/crates/wasmi/src/engine/regmach/translator/stack/mod.rs b/crates/wasmi/src/engine/regmach/translator/stack/mod.rs index b1ac7fab10..cab28a8922 100644 --- a/crates/wasmi/src/engine/regmach/translator/stack/mod.rs +++ b/crates/wasmi/src/engine/regmach/translator/stack/mod.rs @@ -2,11 +2,10 @@ mod consts; mod provider; mod register_alloc; -use self::register_alloc::RegisterSpace; pub use self::{ consts::{FuncLocalConsts, FuncLocalConstsIter}, provider::{ProviderStack, TaggedProvider}, - register_alloc::RegisterAlloc, + register_alloc::{RegisterAlloc, RegisterSpace}, }; use super::TypedValue; use crate::{ @@ -382,4 +381,9 @@ impl ValueStack { pub fn defrag_register(&mut self, register: Register) -> Register { self.reg_alloc.defrag_register(register) } + + /// Returns the [`RegisterSpace`] for the given [`Register`]. + pub fn get_register_space(&self, register: Register) -> RegisterSpace { + self.reg_alloc.register_space(register) + } } diff --git a/crates/wasmi/src/engine/regmach/translator/visit.rs b/crates/wasmi/src/engine/regmach/translator/visit.rs index 0825ab6853..1de6dec6e5 100644 --- a/crates/wasmi/src/engine/regmach/translator/visit.rs +++ b/crates/wasmi/src/engine/regmach/translator/visit.rs @@ -247,10 +247,11 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { .push_else_providers(self.alloc.buffer.iter().copied())?; // Create the `else` label and the conditional branch to `else`. let else_label = self.alloc.instr_encoder.new_label(); - let else_offset = self.alloc.instr_encoder.try_resolve_label(else_label)?; - self.alloc - .instr_encoder - .push_instr(Instruction::branch_eqz(condition, else_offset))?; + self.alloc.instr_encoder.encode_branch_eqz( + &mut self.alloc.stack, + condition, + else_label, + )?; let reachability = IfReachability::both(else_label); // Optionally create the [`Instruction::ConsumeFuel`] for the `then` branch. // @@ -420,11 +421,11 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { if branch_params.is_empty() { // Case: no values need to be copied so we can directly // encode the `br_if` as efficient `branch_nez`. - let branch_offset = - self.alloc.instr_encoder.try_resolve_label(branch_dst)?; - self.alloc - .instr_encoder - .push_instr(Instruction::branch_nez(condition, branch_offset))?; + self.alloc.instr_encoder.encode_branch_nez( + &mut self.alloc.stack, + condition, + branch_dst, + )?; return Ok(()); } self.alloc @@ -442,11 +443,11 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { // no copies are required. // // This means we can encode the `br_if` as efficient `branch_nez`. - let branch_offset = - self.alloc.instr_encoder.try_resolve_label(branch_dst)?; - self.alloc - .instr_encoder - .push_instr(Instruction::branch_nez(condition, branch_offset))?; + self.alloc.instr_encoder.encode_branch_nez( + &mut self.alloc.stack, + condition, + branch_dst, + )?; return Ok(()); } // Case: We need to copy the branch inputs to where the @@ -459,11 +460,11 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { // and finally perform the actual branch to the target // control frame. let skip_label = self.alloc.instr_encoder.new_label(); - let skip_offset = self.alloc.instr_encoder.try_resolve_label(skip_label)?; - debug_assert!(!skip_offset.is_init()); - self.alloc - .instr_encoder - .push_instr(Instruction::branch_eqz(condition, skip_offset))?; + self.alloc.instr_encoder.encode_branch_eqz( + &mut self.alloc.stack, + condition, + skip_label, + )?; self.alloc.instr_encoder.encode_copies( &mut self.alloc.stack, branch_params, @@ -840,9 +841,12 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { // computation which allows us to exchange the result register of // this previous instruction instead of encoding another `copy` // instruction as an optimization. - self.alloc - .instr_encoder - .encode_local_set(&self.res, local_register, value)?; + self.alloc.instr_encoder.encode_local_set( + &mut self.alloc.stack, + &self.res, + local_register, + value, + )?; } } self.alloc.instr_encoder.reset_last_instr(); diff --git a/crates/wasmi/src/engine/regmach/translator/visit_register.rs b/crates/wasmi/src/engine/regmach/translator/visit_register.rs index 94021665c7..c8dd1ea489 100644 --- a/crates/wasmi/src/engine/regmach/translator/visit_register.rs +++ b/crates/wasmi/src/engine/regmach/translator/visit_register.rs @@ -1,6 +1,8 @@ use crate::engine::regmach::bytecode::{ BinInstr, BinInstrImm16, + BranchBinOpInstr, + BranchBinOpInstrImm, Const16, CopysignImmInstr, Instruction, @@ -80,6 +82,60 @@ impl VisitInputRegisters for Instruction { Instruction::BranchEqz { condition, .. } | Instruction::BranchNez { condition, .. } => f(condition), Instruction::BranchTable { index, .. } => f(index), + + Instruction::BranchI32Eq(instr) => instr.visit_input_registers(f), + Instruction::BranchI32EqImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32Ne(instr) => instr.visit_input_registers(f), + Instruction::BranchI32NeImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LtS(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LtSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LtU(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LtUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LeS(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LeSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LeU(instr) => instr.visit_input_registers(f), + Instruction::BranchI32LeUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GtS(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GtSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GtU(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GtUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GeS(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GeSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GeU(instr) => instr.visit_input_registers(f), + Instruction::BranchI32GeUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64Eq(instr) => instr.visit_input_registers(f), + Instruction::BranchI64EqImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64Ne(instr) => instr.visit_input_registers(f), + Instruction::BranchI64NeImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LtS(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LtSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LtU(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LtUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LeS(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LeSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LeU(instr) => instr.visit_input_registers(f), + Instruction::BranchI64LeUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GtS(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GtSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GtU(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GtUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GeS(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GeSImm(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GeU(instr) => instr.visit_input_registers(f), + Instruction::BranchI64GeUImm(instr) => instr.visit_input_registers(f), + Instruction::BranchF32Eq(instr) => instr.visit_input_registers(f), + Instruction::BranchF32Ne(instr) => instr.visit_input_registers(f), + Instruction::BranchF32Lt(instr) => instr.visit_input_registers(f), + Instruction::BranchF32Le(instr) => instr.visit_input_registers(f), + Instruction::BranchF32Gt(instr) => instr.visit_input_registers(f), + Instruction::BranchF32Ge(instr) => instr.visit_input_registers(f), + Instruction::BranchF64Eq(instr) => instr.visit_input_registers(f), + Instruction::BranchF64Ne(instr) => instr.visit_input_registers(f), + Instruction::BranchF64Lt(instr) => instr.visit_input_registers(f), + Instruction::BranchF64Le(instr) => instr.visit_input_registers(f), + Instruction::BranchF64Gt(instr) => instr.visit_input_registers(f), + Instruction::BranchF64Ge(instr) => instr.visit_input_registers(f), + Instruction::Copy { result, value } => { // Note: for copy instruction unlike all other instructions // we need to also visit the result register since @@ -504,6 +560,18 @@ impl LoadOffset16Instr { } } +impl BranchBinOpInstr { + fn visit_input_registers(&mut self, mut f: impl FnMut(&mut Register)) { + visit_registers!(f, &mut self.lhs, &mut self.rhs); + } +} + +impl BranchBinOpInstrImm { + fn visit_input_registers(&mut self, mut f: impl FnMut(&mut Register)) { + f(&mut self.lhs) + } +} + impl VisitInputRegisters for StoreInstr { fn visit_input_registers(&mut self, mut f: impl FnMut(&mut Register)) { f(&mut self.ptr);