Skip to content

Commit

Permalink
fix!: Partially reimplement StringsBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
N3rdL0rd committed Dec 24, 2024
1 parent 723cbb2 commit 217525f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 47 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ hl/
*.json
build*.bat
pypi_token.txt
testenv
testenv
.crashlink_debug
63 changes: 32 additions & 31 deletions crashlink/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,50 +394,48 @@ def __init__(self) -> None:
self.length.length = 4
self.value: List[str] = []
self.lengths: List[int] = []
self.embedded_lengths: List[VarInt] = []
self.lengths: List[VarInt] = []

def deserialise(self, f: BinaryIO | BytesIO) -> "StringsBlock":
def deserialise(self, f: BinaryIO | BytesIO, nstrings: int) -> "StringsBlock":
self.length.deserialise(f, length=4)
strings_size = self.length.value
dbg_print(f"Found {fmt_bytes(strings_size)} of strings")
strings_data: bytes = f.read(strings_size)

index = 0
while index < strings_size:
string_length = 0
while index + string_length < strings_size and strings_data[index + string_length] != 0:
string_length += 1

if index + string_length >= strings_size:
raise MalformedBytecode("Invalid string: no null terminator found")

string = strings_data[index : index + string_length].decode("utf-8", errors="surrogateescape")
self.value.append(string)
self.lengths.append(string_length)

index += string_length + 1 # Skip the null terminator

for _ in self.value:
self.embedded_lengths.append(VarInt().deserialise(f))

return self
size = self.length.value
sdata: bytes = f.read(size)
strings: List[str] = []
lengths: List[VarInt] = []
curpos = 0
for _ in range(nstrings):
sz = VarInt().deserialise(f)
if curpos + sz.value >= size:
raise ValueError("Invalid string")

str_value = sdata[curpos:curpos + sz.value]
if curpos + sz.value < size and sdata[curpos + sz.value] != 0:
raise ValueError("Invalid string")

strings.append(str_value.decode('utf-8', errors="surrogateescape"))
lengths.append(sz)

curpos += sz.value + 1 # +1 for null terminator
self.value = strings
self.lengths = lengths

def serialise(self) -> bytes:
strings_data = b""
for string in self.value:
strings_data += string.encode("utf-8", errors="surrogateescape") + b"\x00"
self.length.value = len(strings_data)
self.lengths = [len(string) for string in self.value]
self.embedded_lengths = [VarInt(length) for length in self.lengths]
self.lengths = [VarInt(len(string)) for string in self.value]
return b"".join(
[
self.length.serialise(),
strings_data,
b"".join([i.serialise() for i in self.embedded_lengths]),
b"".join([i.serialise() for i in self.lengths]),
]
)




class BytesBlock(Serialisable):
"""
Block of bytes in the bytecode. Contains a list of byte strings and their lengths.
Expand Down Expand Up @@ -1425,7 +1423,7 @@ def deserialise(self, f: BinaryIO | BytesIO, search_magic: bool = True) -> "Byte

dbg_print(f"Strings section starts at {tell(f)}")
self.track_section(f, "strings")
self.strings.deserialise(f)
self.strings.deserialise(f, self.nstrings.value)
dbg_print(f"Strings section ends at {tell(f)}")
assert self.nstrings.value == len(self.strings.value), "nstrings and len of strings don't match!"

Expand All @@ -1442,7 +1440,7 @@ def deserialise(self, f: BinaryIO | BytesIO, search_magic: bool = True) -> "Byte
self.ndebugfiles.deserialise(f)
dbg_print(f"Number of debug files: {self.ndebugfiles.value}")
self.track_section(f, "debugfiles")
self.debugfiles.deserialise(f)
self.debugfiles.deserialise(f, self.ndebugfiles.value)
else:
self.ndebugfiles = None
self.debugfiles = None
Expand Down Expand Up @@ -1546,7 +1544,10 @@ def const_str(self, gindex: int) -> str:
obj_fields = obj.resolve_fields(self)
if len(obj_fields) != 2:
raise ValueError(f"Global {gindex} seems malformed!")
return self.initialized_globals[gindex][obj_fields[0].name.resolve(self)]
res = self.initialized_globals[gindex][obj_fields[0].name.resolve(self)]
if not isinstance(res, str):
raise TypeError(f"This should never happen!")
return res

def serialise(self, auto_set_meta: bool = True) -> bytes:
start_time = datetime.now()
Expand Down
52 changes: 38 additions & 14 deletions crashlink/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,14 @@ def graph(self, code: Bytecode) -> str:
dot.append("}")
return "\n".join(dot)

def num_incoming(self, node: CFNode) -> int:
count = 0
def predecessors(self, node: CFNode) -> List[CFNode]:
"""Get predecessors of a node"""
preds = []
for n in self.nodes:
for branch, _ in n.branches:
if branch == node:
count += 1
return count

for succ, _ in n.branches:
if succ == node:
preds.append(n)
return preds

class IRStatement(ABC):
def __init__(self, code: Bytecode):
Expand Down Expand Up @@ -605,6 +605,28 @@ def __repr__(self) -> str:
return f"<IRConditional: if {self.condition} then\n{self.true_block}\nelse\n{self.false_block}>"


class IRLoop(IRStatement):
"""Loop statement"""

def __init__(self, code: Bytecode, condition: IRExpression, body: IRBlock):
super().__init__(code)
self.condition = condition
self.body = body

def __repr__(self) -> str:
return f"<IRLoop: while {self.condition}\n{self.body}>"


class IRBreak(IRStatement):
"""Break statement"""

def __init__(self, code: Bytecode):
super().__init__(code)

def __repr__(self) -> str:
return "<IRBreak>"


class IRReturn(IRStatement):
"""Return statement"""

Expand Down Expand Up @@ -722,13 +744,15 @@ def _find_convergence(self, true_node: CFNode, false_node: CFNode, visited: Set[

return None # No convergence found

def _lift_block(self, node: CFNode, visited: Optional[Set[CFNode]] = None) -> IRBlock:
"""Lift a control flow node to an IR block"""
def _lift_block(self, node: CFNode,
visited: Optional[Set[CFNode]] = None) -> IRBlock:
if visited is None:
visited = set()

if node in visited:
return IRBlock(self.code)

if node in visited:
return IRBlock(self.code) # Return empty block for cycles
visited.add(node)

visited.add(node)
block = IRBlock(self.code)
Expand Down Expand Up @@ -799,13 +823,13 @@ def _lift_block(self, node: CFNode, visited: Optional[Set[CFNode]] = None) -> IR
# it's an empty block and that's what comes *after* the conditional branches altogether.
should_lift_t = True
should_lift_f = True
if self.cfg.num_incoming(true_branch) > 1:
if len(self.cfg.predecessors(true_branch)) > 1:
should_lift_t = False
if self.cfg.num_incoming(false_branch) > 1:
if len(self.cfg.predecessors(false_branch)) > 1:
should_lift_f = False

if not should_lift_t and not should_lift_f:
print("WARNING: Skipping conditional due to weird incoming branches.")
dbg_print("Warning: Skipping conditional due to weird incoming branches.")
continue

cond_map = {
Expand Down
9 changes: 8 additions & 1 deletion crashlink/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

from io import BytesIO
from typing import Any, BinaryIO
import os

def _is_debug() -> bool:
"""
Determine whether or not to enable debug mode.
"""
return bool(os.getenv("CRASHLINK_DEBUG", False)) or bool(os.getenv("DEBUG", False)) or os.path.exists(".crashlink_debug")

VERSION: str = "v0.0.1a"
"""
Expand All @@ -15,7 +22,7 @@
String displayed in the help message for the CLI.
"""

DEBUG: bool = False
DEBUG: bool = False or _is_debug()
"""
Whether to enable certain features meant only for development or debugging of crashlink.
"""
Expand Down

0 comments on commit 217525f

Please sign in to comment.