Skip to content

Commit

Permalink
Feat: Support fine-grained contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
mahaloz committed Jul 3, 2024
1 parent 6aa6a56 commit 384638b
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 65 deletions.
10 changes: 5 additions & 5 deletions examples/change_watcher_plugin/bs_change_watcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def create_plugin(*args, **kwargs):

from libbs.api import DecompilerInterface
from libbs.artifacts import (
FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment
FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment, Context
)

deci = DecompilerInterface.discover(
Expand All @@ -25,12 +25,12 @@ def create_plugin(*args, **kwargs):
)
# create a function to print a string in the decompiler console
decompiler_printer = lambda *x, **y: deci.print(f"Changed {x}")
ctx_printer = lambda *x, **y: deci.print(f"Context changed: {x}")
# register the callback for all the types we want to print
deci.artifact_write_callbacks = {
typ: [decompiler_printer] for typ in (FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment,)
deci.artifact_change_callbacks = {
typ: [decompiler_printer] for typ in (
FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment, Context
)
}
deci.gui_ctx_change_callbacks.append(ctx_printer)

# register a menu to open when you right click on the psuedocode view
deci.gui_register_ctx_menu(
Expand Down
2 changes: 1 addition & 1 deletion libbs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.9.2"
__version__ = "1.10.0"

import logging
logging.getLogger("libbs").addHandler(logging.NullHandler())
Expand Down
58 changes: 34 additions & 24 deletions libbs/api/decompiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Artifact,
Function, FunctionHeader, StackVariable,
Comment, GlobalVariable, Patch,
Enum, Struct, FunctionArgument
Enum, Struct, FunctionArgument, Context
)
from libbs.decompilers import SUPPORTED_DECOMPILERS, ANGR_DECOMPILER, \
BINJA_DECOMPILER, IDA_DECOMPILER, GHIDRA_DECOMPILER
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
gui_init_args: Optional[Tuple] = None,
gui_init_kwargs: Optional[Dict] = None,
# [artifact_class] = list(callback_func)
artifact_write_callbacks: Optional[Dict[Type[Artifact], List[Callable]]] = None,
artifact_change_callbacks: Optional[Dict[Type[Artifact], List[Callable]]] = None,
thread_artifact_callbacks: bool = True,
):
self.name = name
Expand All @@ -67,11 +67,9 @@ def __init__(
self.qt_version = qt_version
self._error_on_artifact_duplicates = error_on_artifact_duplicates

# GUI things
self.headless = headless
self._headless_dec_path = Path(headless_dec_path) if headless_dec_path else None
self._binary_path = Path(binary_path) if binary_path else None

self._init_plugin = init_plugin
self._unparsed_gui_ctx_actions = gui_ctx_menu_actions or {}
# (category, name, action_string, callback_func)
Expand All @@ -84,7 +82,7 @@ def __init__(
self.artifact_write_lock = threading.Lock()

# callback functions, keyed by Artifact class
self.artifact_write_callbacks = artifact_write_callbacks or defaultdict(list)
self.artifact_change_callbacks = artifact_change_callbacks or defaultdict(list)
self._thread_artifact_callbacks = thread_artifact_callbacks

# artifact dict aliases:
Expand Down Expand Up @@ -155,11 +153,11 @@ def shutdown(self):
# GUI API
#

def gui_active_context(self) -> libbs.artifacts.Function:
def gui_active_context(self) -> Optional[libbs.artifacts.Context]:
"""
Returns an libbs Function. Currently only functions are supported as current contexts.
This function will be called very frequently, so its important that its implementation is fast
and can be done many times in the decompiler.
Returns the active location that the user is currently _clicked_ on in the decompiler.
This is returned as a Context object, which can address and screen naming information dependent
on the decompilers exposed data.
"""
raise NotImplementedError

Expand Down Expand Up @@ -245,6 +243,18 @@ def binary_path(self) -> Optional[str]:
"""
return self._binary_path

def fast_get_function(self, func_addr) -> Optional[Function]:
"""
Attempts to get a light version of the Function at func_addr.
This function implements special logic to be faster than grabbing all light-functions, or grabbing
a decompiled function. Use this API in the case where you may need to get a single functions info
many times in a loop.
@param func_addr:
@return:
"""
raise NotImplementedError

def get_func_size(self, func_addr) -> int:
"""
Returns the size of a function
Expand Down Expand Up @@ -491,20 +501,20 @@ def _set_function_header(self, fheader: FunctionHeader, **kwargs) -> bool:
# lift it ONCE inside this function. Each one will return the lifted form, for easier overriding.
#

def gui_context_changed(self, view_name: str, func: Optional[Function] = None, addr: Optional[int] = None, **kwargs):
if not self._watchers_started:
return None, None, None

lifted_func = self.art_lifter.lift(func) if func is not None else None
lifted_addr = self.art_lifter.lift_addr(addr) if addr is not None else None
for callback_func in self.gui_ctx_change_callbacks:
threading.Thread(target=callback_func, args=(view_name, lifted_func, lifted_addr), kwargs=kwargs, daemon=True).start()
def gui_context_changed(self, ctx: Context, **kwargs) -> libbs.artifacts.Context:
# XXX: should this be lifted?
for callback_func in self.artifact_change_callbacks[Context]:
args = (ctx,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
else:
callback_func(*args, **kwargs)

return lifted_func, lifted_addr, view_name
return ctx

def function_header_changed(self, fheader: FunctionHeader, **kwargs) -> FunctionHeader:
lifted_fheader = self.art_lifter.lift(fheader)
for callback_func in self.artifact_write_callbacks[FunctionHeader]:
for callback_func in self.artifact_change_callbacks[FunctionHeader]:
args = (lifted_fheader,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
Expand All @@ -515,7 +525,7 @@ def function_header_changed(self, fheader: FunctionHeader, **kwargs) -> Function

def stack_variable_changed(self, svar: StackVariable, **kwargs) -> StackVariable:
lifted_svar = self.art_lifter.lift(svar)
for callback_func in self.artifact_write_callbacks[StackVariable]:
for callback_func in self.artifact_change_callbacks[StackVariable]:
args = (lifted_svar,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
Expand All @@ -527,7 +537,7 @@ def stack_variable_changed(self, svar: StackVariable, **kwargs) -> StackVariable
def comment_changed(self, comment: Comment, deleted=False, **kwargs) -> Comment:
kwargs["deleted"] = deleted
lifted_cmt = self.art_lifter.lift(comment)
for callback_func in self.artifact_write_callbacks[Comment]:
for callback_func in self.artifact_change_callbacks[Comment]:
args = (lifted_cmt,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
Expand All @@ -539,7 +549,7 @@ def comment_changed(self, comment: Comment, deleted=False, **kwargs) -> Comment:
def struct_changed(self, struct: Struct, deleted=False, **kwargs) -> Struct:
kwargs["deleted"] = deleted
lifted_struct = self.art_lifter.lift(struct)
for callback_func in self.artifact_write_callbacks[Struct]:
for callback_func in self.artifact_change_callbacks[Struct]:
args = (lifted_struct,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
Expand All @@ -551,7 +561,7 @@ def struct_changed(self, struct: Struct, deleted=False, **kwargs) -> Struct:
def enum_changed(self, enum: Enum, deleted=False, **kwargs) -> Enum:
kwargs["deleted"] = deleted
lifted_enum = self.art_lifter.lift(enum)
for callback_func in self.artifact_write_callbacks[Enum]:
for callback_func in self.artifact_change_callbacks[Enum]:
args = (lifted_enum,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
Expand All @@ -562,7 +572,7 @@ def enum_changed(self, enum: Enum, deleted=False, **kwargs) -> Enum:

def global_variable_changed(self, gvar: GlobalVariable, **kwargs) -> GlobalVariable:
lifted_gvar = self.art_lifter.lift(gvar)
for callback_func in self.artifact_write_callbacks[GlobalVariable]:
for callback_func in self.artifact_change_callbacks[GlobalVariable]:
args = (lifted_gvar,)
if self._thread_artifact_callbacks:
threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start()
Expand Down
11 changes: 5 additions & 6 deletions libbs/artifacts/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

from .artifact import Artifact
from .func import Function


class Context(Artifact):
Expand All @@ -11,14 +10,14 @@ class Context(Artifact):
"screen_name"
)

def __init__(self, addr: int = None, func: Optional[Function] = None, screen_name: str = None, **kwargs):
def __init__(self, addr: int = None, func_addr: Optional[int] = None, screen_name: str = None, **kwargs):
self.addr: Optional[int] = addr
self.func_addr: Optional[int] = func_addr
self.screen_name: str = screen_name
super().__init__(**kwargs)
self.addr = addr
self.func_addr = func
self.screen_name = screen_name

def __str__(self):
post_text = f" name={self.screen_name}" if self.screen_name else ""
post_text = f" screen={self.screen_name}" if self.screen_name else ""
if self.func_addr is not None:
post_text = f"@{hex(self.func_addr)}" + post_text
if self.addr is not None:
Expand Down
8 changes: 8 additions & 0 deletions libbs/decompilers/binja/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ def binary_path(self) -> Optional[str]:
except Exception:
return None

def fast_get_function(self, func_addr) -> Optional[Function]:
func_addr = self.art_lifter.lower_addr(func_addr)
func = self.bv.get_function_at(func_addr)
if not func:
return None

return self.bn_func_to_bs(func)

def get_func_size(self, func_addr) -> int:
func_addr = self.art_lifter.lower_addr(func_addr)
func = self.bv.get_function_at(func_addr)
Expand Down
17 changes: 17 additions & 0 deletions libbs/decompilers/ghidra/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ def gui_goto(self, func_addr) -> None:
# Mandatory API
#

def fast_get_function(self, func_addr) -> Optional[Function]:
lowered_addr = self.art_lifter.lower_addr(func_addr)
gfuncs = self.__fast_function(lowered_addr)
gfunc = gfuncs[0] if gfuncs else None
if gfunc is None:
_l.error(f"Func does not exist at {lowered_addr}")

bs_func = self._gfunc_to_bsfunc(gfunc)
lifted_func = self.art_lifter.lift(bs_func)
return lifted_func

@property
def binary_base_addr(self) -> int:
# TODO: this is a hack for a dumb cache, and can cause bugs, but good enough for now:
Expand Down Expand Up @@ -810,6 +821,12 @@ def isinstance(obj, cls):
# Internal functions that are very dangerous
#

@ui_remote_eval
def __fast_function(self, lowered_addr: int) -> List["GhidraFunction"]:
return [
self.currentProgram.getFunctionManager().getFunctionContaining(self.flat_api.toAddr(lowered_addr))
]

@ui_remote_eval
def __functions(self) -> List[Tuple[int, str, int]]:
return [
Expand Down
6 changes: 6 additions & 0 deletions libbs/decompilers/ida/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def _requires_decompilation(*args, **kwargs):
return _requires_decompilation


def get_func_ret_type(ea):
tinfo = ida_typeinf.tinfo_t()
got_info = idaapi.get_tinfo(tinfo, ea)
return tinfo.get_rettype() if got_info else None


def set_func_ret_type(ea, return_type_str):
tinfo = ida_typeinf.tinfo_t()
if not idaapi.get_tinfo(tinfo, ea):
Expand Down
35 changes: 26 additions & 9 deletions libbs/decompilers/ida/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from . import compat
from libbs.artifacts import (
FunctionHeader, StackVariable,
Comment, GlobalVariable, Enum, Struct
Comment, GlobalVariable, Enum, Struct, Context
)


Expand All @@ -56,6 +56,16 @@
IDA_EXTRA_CMT = "extra"
IDA_CMT_TYPES = {IDA_CMT_CMT, IDA_EXTRA_CMT, IDA_RANGE_CMT}

FORM_TYPE_TO_NAME = {
idaapi.BWN_PSEUDOCODE: "decompilation",
idaapi.BWN_DISASM: "disassembly",
idaapi.BWN_FUNCS: "functions",
idaapi.BWN_STRUCTS: "structs",
idaapi.BWN_ENUMS: "enums",
}

FUNC_FORMS = {"decompilation", "disassembly"}


def while_should_watch(func):
@functools.wraps(func)
Expand Down Expand Up @@ -90,18 +100,25 @@ def __init__(self, interface: "IDAInterface"):
super(ScreenHook, self).__init__()

def view_click(self, view, event):
form_type = idaapi.get_widget_type(view)
decomp_view = idaapi.get_widget_vdui(view)
if not form_type:
if not self.interface._artifact_watchers_started:
return

is_disass_view = form_type == idaapi.BWN_DISASM
ea = idc.get_screen_ea()
if ea is None:
form_type = idaapi.get_widget_type(view)
#decomp_view = idaapi.get_widget_vdui(view)
if not form_type:
return

self.interface.update_active_context(ea)
self.interface.gui_context_changed("disassembly" if is_disass_view else "decompilation", addr=ea)
view_name = FORM_TYPE_TO_NAME.get(form_type, "unknown")
ctx = Context(screen_name=view_name)
if view_name in FUNC_FORMS:
ctx.addr = idaapi.get_screen_ea()
func = idaapi.get_func(ctx.addr)
if func is not None:
ctx.func_addr = func.start_ea

ctx = self.interface.art_lifter.lift(ctx)
self.interface._gui_active_context = ctx
self.interface.gui_context_changed(ctx)


class IDAHotkeyHook(ida_kernwin.UI_Hooks):
Expand Down
Loading

0 comments on commit 384638b

Please sign in to comment.