diff --git a/internal/engine/wazevo/backend/isa/amd64/instr.go b/internal/engine/wazevo/backend/isa/amd64/instr.go
index 41a019d0bb..511299042d 100644
--- a/internal/engine/wazevo/backend/isa/amd64/instr.go
+++ b/internal/engine/wazevo/backend/isa/amd64/instr.go
@@ -1585,6 +1585,7 @@ var defKinds = [instrMax]defKind{
 	imm: defKindOp2,
 	unaryRmR: defKindOp2,
 	xmmUnaryRmR: defKindOp2,
+	xmmRmR: defKindNone,
 	mov64MR: defKindOp2,
 	movsxRmR: defKindOp2,
 	movzxRmR: defKindOp2,
@@ -1641,6 +1642,7 @@ var useKinds = [instrMax]useKind{
 	imm: useKindNone,
 	unaryRmR: useKindOp1,
 	xmmUnaryRmR: useKindOp1,
+	xmmRmR: useKindOp1Op2Reg,
 	mov64MR: useKindOp1,
 	movzxRmR: useKindOp1,
 	movsxRmR: useKindOp1,
diff --git a/internal/engine/wazevo/backend/isa/amd64/instr_encoding.go b/internal/engine/wazevo/backend/isa/amd64/instr_encoding.go
index 86015ea0ab..7fa545458d 100644
--- a/internal/engine/wazevo/backend/isa/amd64/instr_encoding.go
+++ b/internal/engine/wazevo/backend/isa/amd64/instr_encoding.go
@@ -523,6 +523,7 @@ func (i *instruction) encode(c backend.Compiler) (needsLabelResolution bool) {
 
 	case div:
 		panic("TODO")
+
 	case mulHi:
 		var prefix legacyPrefixes
 		rex := rexInfo(0)
diff --git a/internal/engine/wazevo/backend/isa/amd64/machine.go b/internal/engine/wazevo/backend/isa/amd64/machine.go
index 8cfd5c5231..133eb8a189 100644
--- a/internal/engine/wazevo/backend/isa/amd64/machine.go
+++ b/internal/engine/wazevo/backend/isa/amd64/machine.go
@@ -280,6 +280,12 @@ func (m *machine) LowerInstr(instr *ssa.Instruction) {
 		m.lowerCtz(instr)
 	case ssa.OpcodePopcnt:
 		m.lowerUnaryRmR(instr, unaryRmROpcodePopcnt)
+	case ssa.OpcodeFadd, ssa.OpcodeFsub, ssa.OpcodeFmul, ssa.OpcodeFdiv:
+		m.lowerXmmRmR(instr)
+	case ssa.OpcodeFabs, ssa.OpcodeFneg:
+		m.lowerFabsFneg(instr)
+	case ssa.OpcodeSqrt:
+		m.lowerSqrt(instr)
 	case ssa.OpcodeUndefined:
 		m.insert(m.allocateInstr().asUD2())
 	case ssa.OpcodeExitWithCode:
@@ -682,6 +688,127 @@ func (m *machine) lowerShiftR(si *ssa.Instruction, op shiftROp) {
 	m.copyTo(tmpDst, rd)
 }
 
+// lowerXmmRmR lowers Fadd, Fsub, Fmul and Fdiv, selecting the sd or ss variant of the
+// SSE opcode by the bit width of the operands. It panics on any other opcode or type.
+func (m *machine) lowerXmmRmR(instr *ssa.Instruction) {
+	x, y := instr.Arg2()
+	if !x.Type().IsFloat() {
+		panic("BUG")
+	}
+	_64 := x.Type().Bits() == 64
+
+	var op sseOpcode
+	if _64 {
+		switch instr.Opcode() {
+		case ssa.OpcodeFadd:
+			op = sseOpcodeAddsd
+		case ssa.OpcodeFsub:
+			op = sseOpcodeSubsd
+		case ssa.OpcodeFmul:
+			op = sseOpcodeMulsd
+		case ssa.OpcodeFdiv:
+			op = sseOpcodeDivsd
+		default:
+			panic("BUG")
+		}
+	} else {
+		switch instr.Opcode() {
+		case ssa.OpcodeFadd:
+			op = sseOpcodeAddss
+		case ssa.OpcodeFsub:
+			op = sseOpcodeSubss
+		case ssa.OpcodeFmul:
+			op = sseOpcodeMulss
+		case ssa.OpcodeFdiv:
+			op = sseOpcodeDivss
+		default:
+			panic("BUG")
+		}
+	}
+
+	xDef, yDef := m.c.ValueDefinition(x), m.c.ValueDefinition(y)
+	rn := m.getOperand_Mem_Reg(yDef)
+	rm := m.getOperand_Reg(xDef)
+	rd := m.c.VRegOf(instr.Return())
+
+	// The instruction overwrites its register operand with the result, so x is first
+	// copied to a temp register in case rm is referenced again later.
+	tmp := m.copyToTmp(rm.r)
+
+	xmm := m.allocateInstr().asXmmRmR(op, rn, tmp)
+	m.insert(xmm)
+
+	m.copyTo(tmp, rd)
+}
+
+// lowerSqrt lowers Sqrt using sseOpcodeSqrtsd for 64-bit operands and sseOpcodeSqrtss otherwise.
+func (m *machine) lowerSqrt(instr *ssa.Instruction) {
+	x := instr.Arg()
+	if !x.Type().IsFloat() {
+		panic("BUG")
+	}
+	_64 := x.Type().Bits() == 64
+	var op sseOpcode
+	if _64 {
+		op = sseOpcodeSqrtsd
+	} else {
+		op = sseOpcodeSqrtss
+	}
+
+	xDef := m.c.ValueDefinition(x)
+	rm := m.getOperand_Mem_Reg(xDef)
+	rd := m.c.VRegOf(instr.Return())
+
+	xmm := m.allocateInstr().asXmmUnaryRmR(op, rm, rd)
+	m.insert(xmm)
+}
+
+// lowerFabsFneg lowers Fabs and Fneg into a bitwise operation between the operand and
+// a constant mask that targets the sign bit. It panics on any other opcode or type.
+func (m *machine) lowerFabsFneg(instr *ssa.Instruction) {
+	x := instr.Arg()
+	if !x.Type().IsFloat() {
+		panic("BUG")
+	}
+	_64 := x.Type().Bits() == 64
+	var op sseOpcode
+	var mask uint64
+	// Fabs clears the sign bit by AND-ing with every other bit set, while Fneg flips it
+	// by XOR-ing with only the sign bit set.
+	if _64 {
+		switch instr.Opcode() {
+		case ssa.OpcodeFabs:
+			mask, op = 0x7fffffffffffffff, sseOpcodeAndpd
+		case ssa.OpcodeFneg:
+			mask, op = 0x8000000000000000, sseOpcodeXorpd
+		default:
+			panic("BUG")
+		}
+	} else {
+		switch instr.Opcode() {
+		case ssa.OpcodeFabs:
+			mask, op = 0x7fffffff, sseOpcodeAndps
+		case ssa.OpcodeFneg:
+			mask, op = 0x80000000, sseOpcodeXorps
+		default:
+			panic("BUG")
+		}
+	}
+
+	tmp := m.c.AllocateVReg(x.Type())
+
+	xDef := m.c.ValueDefinition(x)
+	rm := m.getOperand_Reg(xDef)
+	rd := m.c.VRegOf(instr.Return())
+
+	m.lowerFconst(tmp, mask, _64)
+
+	xmm := m.allocateInstr().asXmmRmR(op, rm, tmp)
+	m.insert(xmm)
+
+	m.copyTo(tmp, rd)
+}
+
 func (m *machine) lowerStore(si *ssa.Instruction) {
 	value, ptr, offset, storeSizeInBits := si.StoreData()
 	rm := m.getOperand_Reg(m.c.ValueDefinition(value))
diff --git a/internal/engine/wazevo/backend/isa/amd64/machine_test.go b/internal/engine/wazevo/backend/isa/amd64/machine_test.go
index 8bbac9624b..c9e28f75f5 100644
--- a/internal/engine/wazevo/backend/isa/amd64/machine_test.go
+++ b/internal/engine/wazevo/backend/isa/amd64/machine_test.go
@@ -344,7 +344,7 @@ L2:
 	}
 }
 
-func Test_machine_lowerCtz(t *testing.T) {
+func TestMachine_lowerCtz(t *testing.T) {
 	for _, tc := range []struct {
 		name string
 		setup func(*mockCompiler, ssa.Builder, *machine) *backend.SSAValueDefinition
diff --git a/internal/engine/wazevo/e2e_test.go b/internal/engine/wazevo/e2e_test.go
index 6ee05d9fd6..fd5de3ad40 100644
--- a/internal/engine/wazevo/e2e_test.go
+++ b/internal/engine/wazevo/e2e_test.go
@@ -151,6 +151,35 @@ func TestE2E(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "float_arithm", m: testcases.FloatArithm.Module,
+			calls: []callCase{
+				{
+					params: []uint64{
+						math.Float64bits(25), math.Float64bits(5), uint64(math.Float32bits(25)), uint64(math.Float32bits(5)),
+					},
+					expResults: []uint64{
+						math.Float64bits(-25), // neg
+						math.Float64bits(25),  // abs(neg)
+
+						math.Float64bits(5),   // sqrt
+						math.Float64bits(30),  // add
+						math.Float64bits(20),  // sub
+						math.Float64bits(125), // mul
+						math.Float64bits(5),   // div
+
+						uint64(math.Float32bits(-25)), // neg
+						uint64(math.Float32bits(25)),  // abs(neg)
+
+						uint64(math.Float32bits(5)),   // sqrt
+						uint64(math.Float32bits(30)),  // add
+						uint64(math.Float32bits(20)),  // sub
+						uint64(math.Float32bits(125)), // mul
+						uint64(math.Float32bits(5)),   // div
+					},
+				},
+			},
+		},
 		{
 			name: "fibonacci_recursive", m: testcases.FibonacciRecursive.Module,
 			calls: []callCase{
diff --git a/internal/engine/wazevo/ssa/type.go b/internal/engine/wazevo/ssa/type.go
index 4e320db36a..e8c8cd9de3 100644
--- a/internal/engine/wazevo/ssa/type.go
+++ b/internal/engine/wazevo/ssa/type.go
@@ -48,6 +48,11 @@ func (t Type) IsInt() bool {
 	return t == TypeI32 || t == TypeI64
 }
 
+// IsFloat returns true if the type is a floating point type.
+func (t Type) IsFloat() bool {
+	return t == TypeF32 || t == TypeF64
+}
+
 // Bits returns the number of bits required to represent the type.
 func (t Type) Bits() byte {
 	switch t {
diff --git a/internal/engine/wazevo/testcases/testcases.go b/internal/engine/wazevo/testcases/testcases.go
index 6ec759c259..0a8d304014 100644
--- a/internal/engine/wazevo/testcases/testcases.go
+++ b/internal/engine/wazevo/testcases/testcases.go
@@ -81,7 +81,7 @@ var (
 		}, nil),
 	}
 	ArithmReturn = TestCase{
-		Name: "add_sub_params_return",
+		Name: "arithm return",
 		Module: SingleFunctionModule(
 			wasm.FunctionType{
 				Params: []wasm.ValueType{i32, i32, i32, i64, i64, i64},
@@ -989,6 +989,68 @@ var (
 			wasm.OpcodeEnd,
 		}, []wasm.ValueType{}),
 	}
+	FloatArithm = TestCase{
+		Name: "float_arithm",
+		Module: SingleFunctionModule(wasm.FunctionType{
+			Params:  []wasm.ValueType{f64, f64, f32, f32},
+			Results: []wasm.ValueType{f64, f64, f64, f64, f64, f64, f64, f32, f32, f32, f32, f32, f32, f32},
+		}, []byte{
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeF64Neg,
+
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeF64Neg,
+			wasm.OpcodeF64Abs,
+
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeF64Sqrt,
+
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeLocalGet, 1,
+			wasm.OpcodeF64Add,
+
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeLocalGet, 1,
+			wasm.OpcodeF64Sub,
+
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeLocalGet, 1,
+			wasm.OpcodeF64Mul,
+
+			wasm.OpcodeLocalGet, 0,
+			wasm.OpcodeLocalGet, 1,
+			wasm.OpcodeF64Div,
+
+			// 32-bit floats.
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeF32Neg,
+
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeF32Neg,
+			wasm.OpcodeF32Abs,
+
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeF32Sqrt,
+
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeLocalGet, 3,
+			wasm.OpcodeF32Add,
+
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeLocalGet, 3,
+			wasm.OpcodeF32Sub,
+
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeLocalGet, 3,
+			wasm.OpcodeF32Mul,
+
+			wasm.OpcodeLocalGet, 2,
+			wasm.OpcodeLocalGet, 3,
+			wasm.OpcodeF32Div,
+
+			wasm.OpcodeEnd,
+		}, []wasm.ValueType{}),
+	}
 	FloatConversions = TestCase{
 		Name: "float_conversions",
 		Module: SingleFunctionModule(wasm.FunctionType{