From c9de71dd69ba9049ab8dea77f994ed5b783f2878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Tempel?= Date: Tue, 30 Apr 2024 21:24:23 +0200 Subject: [PATCH] Fix various inaccuracies in the RISC-V lifter (#64) * Fix signedness handling in comparison instructions * Fix extraction of shift value for immediate instructions The shift value must be treated as an unsigned integer, otherwise a shift value such as '0b11111' would be interpreted as -1 instead of 31. * Fix lifting of shift instructions For arithmetic shift instructions, use the .sar() method provided by VexValue. For R-type shift instructions, extract the shift amount by reading the lower 5 bits of the register (previously the shift amount was incorrectly extracted from the register index, not the register value). * Fix zero- and sign-extension of load instructions --- angr_platforms/risc_v/instrs_riscv/i_instr.py | 21 +++++++------------ .../instrs_riscv/instruction_patterns.py | 2 +- angr_platforms/risc_v/instrs_riscv/r_instr.py | 18 ++++++---------- 3 files changed, 15 insertions(+), 26 deletions(-) diff --git a/angr_platforms/risc_v/instrs_riscv/i_instr.py b/angr_platforms/risc_v/instrs_riscv/i_instr.py index 7a640c3..545dd40 100644 --- a/angr_platforms/risc_v/instrs_riscv/i_instr.py +++ b/angr_platforms/risc_v/instrs_riscv/i_instr.py @@ -82,7 +82,8 @@ def extra_constraints(self, data, bitstream): return data def compute_result(self, src1, _): - return (~((~src1) >> self.get_shift_amount())) & self.constant(0xffffffff, Type.int_32) + shftamnt = self.get_shift_amount() + return src1.sar(shftamnt).cast_to(Type.int_32) class Instruction_SLTI(I_Instruction): @@ -93,10 +94,7 @@ class Instruction_SLTI(I_Instruction): # TODO: ISA manual mentions sign extension, check if properly implemented def compute_result(self, src1, imm): - src1.is_signed = True - imm.is_signed = True - val = 1 if src1.signed < imm.signed else 0 - return self.constant(val, Type.int_32) + return (src1.signed < imm.signed).ite(1, 0) 
class Instruction_SLTIU(I_Instruction): @@ -105,10 +103,7 @@ class Instruction_SLTIU(I_Instruction): name = 'SLTIU' def compute_result(self, src1, imm): - src1.is_signed = False - imm.is_signed = False - val = 1 if src1 < imm else 0 - return self.constant(val, Type.int_32) + return (src1 < imm).ite(1, 0) class Instruction_LB(I_Instruction): func3='000' @@ -117,7 +112,7 @@ class Instruction_LB(I_Instruction): def compute_result(self, src, imm): addr = src + imm.signed - value = self.load(addr, Type.int_8).cast_to(Type.int_32) + value = self.load(addr, Type.int_8).widen_signed(Type.int_32) return value.signed class Instruction_LH(I_Instruction): @@ -127,7 +122,7 @@ class Instruction_LH(I_Instruction): def compute_result(self, src, imm): addr = src + imm - value = self.load(addr, Type.int_16).cast_to(Type.int_32) + value = self.load(addr, Type.int_16).widen_signed(Type.int_32) return value.signed class Instruction_LW(I_Instruction): @@ -148,7 +143,7 @@ class Instruction_LBU(I_Instruction): def compute_result(self, src, imm): addr = src + imm.signed - return self.load(addr, Type.int_8).cast_to(Type.int_32) + return self.load(addr, Type.int_8).widen_unsigned(Type.int_32) class Instruction_LHU(I_Instruction): func3='101' @@ -157,7 +152,7 @@ class Instruction_LHU(I_Instruction): def compute_result(self, src, imm): addr= src+imm.signed - return self.load(addr, Type.int_16).cast_to(Type.int_32) + return self.load(addr, Type.int_16).widen_unsigned(Type.int_32) class Instruction_JALR(I_Instruction): diff --git a/angr_platforms/risc_v/instrs_riscv/instruction_patterns.py b/angr_platforms/risc_v/instrs_riscv/instruction_patterns.py index 3cd866c..a580948 100644 --- a/angr_platforms/risc_v/instrs_riscv/instruction_patterns.py +++ b/angr_platforms/risc_v/instrs_riscv/instruction_patterns.py @@ -94,7 +94,7 @@ def get_imm(self): return self.constant(data, Type.int_32) def get_shift_amount(self): - num = BitArray(bin=self.data['I']).int + num = BitArray(bin=self.data['I']).uint 
return self.constant(num, Type.int_8) def get_optional_func7(self): diff --git a/angr_platforms/risc_v/instrs_riscv/r_instr.py b/angr_platforms/risc_v/instrs_riscv/r_instr.py index 76a467c..5d47958 100644 --- a/angr_platforms/risc_v/instrs_riscv/r_instr.py +++ b/angr_platforms/risc_v/instrs_riscv/r_instr.py @@ -59,7 +59,7 @@ class Instruction_SLL(R_Instruction): name = 'SLL' def compute_result(self, src1, src2): - shftamnt = self.get(int(self.data['S'], 2), Type.int_8) + shftamnt = src2.narrow_low(Type.int_5).cast_to(Type.int_8) return (src1 << shftamnt) & self.constant(0xffffffff, Type.int_32) @@ -70,7 +70,7 @@ class Instruction_SRL(R_Instruction): name = 'SRL' def compute_result(self, src1, src2): - shftamnt = self.get(int(self.data['S'], 2), Type.int_8) + shftamnt = src2.narrow_low(Type.int_5).cast_to(Type.int_8) return (src1 >> shftamnt) & self.constant(0xffffffff, Type.int_32) # Arithmetic shift is not easily mapped, so leaving this as an TODO @@ -83,8 +83,8 @@ class Instruction_SRA(R_Instruction): name = 'SRA' def compute_result(self, src1, src2): - shftamnt = self.get(int(self.data['S'], 2), Type.int_8) - return (~((~src1) >> shftamnt)) & self.constant(0xffffffff, Type.int_32) + shftamnt = src2.narrow_low(Type.int_5).cast_to(Type.int_8) + return src1.sar(shftamnt).cast_to(Type.int_32) class Instruction_SLT(R_Instruction): @@ -94,10 +94,7 @@ class Instruction_SLT(R_Instruction): name='SLT' def compute_result(self, src1, src2): - src1.is_signed = True - src2.is_signed = True - val = 1 if src1 < src2 else 0 - return self.constant(val, Type.int_32) + return (src1.signed < src2.signed).ite(1, 0) class Instruction_SLTU(R_Instruction): @@ -107,10 +104,7 @@ class Instruction_SLTU(R_Instruction): name = 'SLTU' def compute_result(self, src1, src2): - src1.is_signed = False - src1.is_signed = False - val = 1 if src1 < src2 else 0 - return self.constant(val, Type.int_32) + return (src1 < src2).ite(1, 0) class Instruction_MUL(R_Instruction):