Skip to content

Commit

Permalink
Fix various inaccuracies in the RISC-V lifter (#64)
Browse files Browse the repository at this point in the history
* Fix signedness handling in comparison instructions

* Fix extraction of shift value for immediate instructions

The shift value must be treated as an unsigned integer otherwise
a shift value such as '0b11111' would be interpreted as -1 instead
of 31.

* Fix lifting of shift instructions

For arithmetic shift instructions, use the .sra() method provided
by VexValue. For R-type shift instructions, extract the shift amount
by reading the lower 5 bits of the register (previously the shift
amount was incorrectly extracted from the register index, not the
register value).

* Fix zero- and sign-extension of load instructions
  • Loading branch information
nmeum authored Apr 30, 2024
1 parent 694c4dd commit c9de71d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 26 deletions.
21 changes: 8 additions & 13 deletions angr_platforms/risc_v/instrs_riscv/i_instr.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def extra_constraints(self, data, bitstream):
return data

def compute_result(self, src1, _):
return (~((~src1) >> self.get_shift_amount())) & self.constant(0xffffffff, Type.int_32)
shftamnt = self.get_shift_amount()
return src1.sar(shftamnt).cast_to(Type.int_32)


class Instruction_SLTI(I_Instruction):
Expand All @@ -93,10 +94,7 @@ class Instruction_SLTI(I_Instruction):

# TODO: ISA manual mentions sign extension, check if properly implemented
def compute_result(self, src1, imm):
src1.is_signed = True
imm.is_signed = True
val = 1 if src1.signed < imm.signed else 0
return self.constant(val, Type.int_32)
return (src1.signed < imm.signed).ite(1, 0)


class Instruction_SLTIU(I_Instruction):
Expand All @@ -105,10 +103,7 @@ class Instruction_SLTIU(I_Instruction):
name = 'SLTIU'

def compute_result(self, src1, imm):
src1.is_signed = False
imm.is_signed = False
val = 1 if src1 < imm else 0
return self.constant(val, Type.int_32)
return (src1 < imm).ite(1, 0)

class Instruction_LB(I_Instruction):
func3='000'
Expand All @@ -117,7 +112,7 @@ class Instruction_LB(I_Instruction):

def compute_result(self, src, imm):
addr = src + imm.signed
value = self.load(addr, Type.int_8).cast_to(Type.int_32)
value = self.load(addr, Type.int_8).widen_signed(Type.int_32)
return value.signed

class Instruction_LH(I_Instruction):
Expand All @@ -127,7 +122,7 @@ class Instruction_LH(I_Instruction):

def compute_result(self, src, imm):
addr = src + imm
value = self.load(addr, Type.int_16).cast_to(Type.int_32)
value = self.load(addr, Type.int_16).widen_signed(Type.int_32)
return value.signed

class Instruction_LW(I_Instruction):
Expand All @@ -148,7 +143,7 @@ class Instruction_LBU(I_Instruction):
def compute_result(self, src, imm):
addr = src + imm.signed

return self.load(addr, Type.int_8).cast_to(Type.int_32)
return self.load(addr, Type.int_8).widen_unsigned(Type.int_32)

class Instruction_LHU(I_Instruction):
func3='101'
Expand All @@ -157,7 +152,7 @@ class Instruction_LHU(I_Instruction):

def compute_result(self, src, imm):
addr= src+imm.signed
return self.load(addr, Type.int_16).cast_to(Type.int_32)
return self.load(addr, Type.int_16).widen_unsigned(Type.int_32)


class Instruction_JALR(I_Instruction):
Expand Down
2 changes: 1 addition & 1 deletion angr_platforms/risc_v/instrs_riscv/instruction_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_imm(self):
return self.constant(data, Type.int_32)

def get_shift_amount(self):
num = BitArray(bin=self.data['I']).int
num = BitArray(bin=self.data['I']).uint
return self.constant(num, Type.int_8)

def get_optional_func7(self):
Expand Down
18 changes: 6 additions & 12 deletions angr_platforms/risc_v/instrs_riscv/r_instr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Instruction_SLL(R_Instruction):
name = 'SLL'

def compute_result(self, src1, src2):
shftamnt = self.get(int(self.data['S'], 2), Type.int_8)
shftamnt = src2.narrow_low(Type.int_5).cast_to(Type.int_8)
return (src1 << shftamnt) & self.constant(0xffffffff, Type.int_32)


Expand All @@ -70,7 +70,7 @@ class Instruction_SRL(R_Instruction):
name = 'SRL'

def compute_result(self, src1, src2):
shftamnt = self.get(int(self.data['S'], 2), Type.int_8)
shftamnt = src2.narrow_low(Type.int_5).cast_to(Type.int_8)
return (src1 >> shftamnt) & self.constant(0xffffffff, Type.int_32)

# Arithmetic shift is not easily mapped, so leaving this as an TODO
Expand All @@ -83,8 +83,8 @@ class Instruction_SRA(R_Instruction):
name = 'SRA'

def compute_result(self, src1, src2):
shftamnt = self.get(int(self.data['S'], 2), Type.int_8)
return (~((~src1) >> shftamnt)) & self.constant(0xffffffff, Type.int_32)
shftamnt = src2.narrow_low(Type.int_5).cast_to(Type.int_8)
return src1.sar(shftamnt).cast_to(Type.int_32)


class Instruction_SLT(R_Instruction):
Expand All @@ -94,10 +94,7 @@ class Instruction_SLT(R_Instruction):
name='SLT'

def compute_result(self, src1, src2):
src1.is_signed = True
src2.is_signed = True
val = 1 if src1 < src2 else 0
return self.constant(val, Type.int_32)
return (src1.signed < src2.signed).ite(1, 0)


class Instruction_SLTU(R_Instruction):
Expand All @@ -107,10 +104,7 @@ class Instruction_SLTU(R_Instruction):
name = 'SLTU'

def compute_result(self, src1, src2):
src1.is_signed = False
src1.is_signed = False
val = 1 if src1 < src2 else 0
return self.constant(val, Type.int_32)
return (src1 < src2).ite(1, 0)


class Instruction_MUL(R_Instruction):
Expand Down

0 comments on commit c9de71d

Please sign in to comment.