diff --git a/crashlink/__main__.py b/crashlink/__main__.py index 6c5a63c..521659a 100644 --- a/crashlink/__main__.py +++ b/crashlink/__main__.py @@ -3,270 +3,261 @@ """ import argparse +import inspect import os import platform import subprocess import sys import tempfile import webbrowser -from collections.abc import Callable -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple -from . import decomp, disasm +from . import decomp, disasm from .core import Bytecode, Native from .globals import VERSION +class Commands: + """Container class for all CLI commands""" + + def __init__(self, code: Bytecode): + self.code = code + + def _format_help(self, doc: str, cmd: str) -> Tuple[str, str]: + """Formats the docstring for a command. Returns (usage, description)""" + s = doc.strip().split("`") + if len(s) == 1: + return cmd, " ".join(s) + return s[1], s[0] + + + def help(self, args: List[str]) -> None: + """Prints this help message or information on a specific command. `help (command)`""" + commands = self._get_commands() + if args: + for command in args: + if command in commands: + usage, desc = self._format_help(commands[command].__doc__, command) + print(f"{usage} - {desc}") + else: + print(f"Unknown command: {command}") + return + print("Available commands:") + for cmd, func in commands.items(): + usage, desc = self._format_help(func.__doc__, cmd) + print(f"\t{usage} - {desc}") + print("Type 'help ' for information on a specific command.") + + def exit(self, args: List[str]) -> None: + """Exit the program""" + sys.exit() + + def wiki(self, args: List[str]) -> None: + """Open the HLBC wiki in your default browser""" + webbrowser.open("https://github.com/Gui-Yom/hlbc/wiki/Bytecode-file-format") + + def opcodes(self, args: List[str]) -> None: + """Open the HLBC source to opcodes.rs in your default browser""" + webbrowser.open("https://github.com/Gui-Yom/hlbc/blob/master/crates/hlbc/src/opcodes.rs") + + def funcs(self, args: List[str]) -> None: + """List all functions in the bytecode - pass 'std' to not exclude stdlib""" + std = args and args[0] == "std" + for func in self.code.functions: + if disasm.is_std(self.code, func) and not std: + continue + print(disasm.func_header(self.code, func)) + for native in self.code.natives: + if disasm.is_std(self.code, native) and not std: + continue + print(disasm.native_header(self.code, native)) -def cmd_help(args: List[str], code: Bytecode) -> None: - """ - Help command, lists available commands from `COMMANDS`. - """ - if args: - for command in args: - if command in COMMANDS: - print(f"{command} - {COMMANDS[command][1]}") - else: - print(f"Unknown command: {command}") - return - print("Available commands:") - for cmd in COMMANDS: - print(f"\t{cmd} - {COMMANDS[cmd][1]}") - print("Type 'help ' for information on a specific command.") - - -def cmd_funcs(args: List[str], code: Bytecode) -> None: - """ - Prints all functions and natives in the bytecode. If `std` is passed as an argument, it will include stdlib functions. - """ - std = args and args[0] == "std" - for func in code.functions: - if disasm.is_std(code, func) and not std: - continue - print(disasm.func_header(code, func)) - for native in code.natives: - if disasm.is_std(code, native) and not std: - continue - print(disasm.native_header(code, native)) - - -def cmd_entry(args: List[str], code: Bytecode) -> None: - """ - Prints the entrypoint of the bytecode. - """ - entry = code.entrypoint.resolve(code) - print(" Entrypoint:", disasm.func_header(code, entry)) - + def entry(self, args: List[str]) -> None: + """Prints the entrypoint of the bytecode.""" + entry = self.code.entrypoint.resolve(self.code) + print(" Entrypoint:", disasm.func_header(self.code, entry)) -def cmd_fn(args: List[str], code: Bytecode) -> None: - """ - Disassembles a function to pseudocode by findex. - """ - if not args: - print("Usage: fn ") - return - try: - index = int(args[0]) - except ValueError: - print("Invalid index.") - return - for func in code.functions: - if func.findex.value == index: - print(disasm.func(code, func)) + def fn(self, args: List[str]) -> None: + """Disassembles a function to pseudocode by findex. `fn `""" + if len(args) == 0: + print("Usage: fn ") return - for native in code.natives: - if native.findex.value == index: - print(disasm.native_header(code, native)) + try: + index = int(args[0]) + except ValueError: + print("Invalid index.") return - print("Function not found.") - - -def cmd_cfg(args: List[str], code: Bytecode) -> None: - """ - Renders a control flow graph for a given findex and attempts to open it in the default image viewer.s - """ - if not args: - print("Usage: cfg ") - return - try: - index = int(args[0]) - except ValueError: - print("Invalid index.") - return - for func in code.functions: - if func.findex.value == index: - cfg = decomp.CFGraph(func) - print("Building control flow graph...") - cfg.build() - print("DOT:") - dot = cfg.graph(code) - print(dot) - print("Attempting to render graph...") - with tempfile.NamedTemporaryFile(suffix=".dot", delete=False) as f: - f.write(dot.encode()) - dot_file = f.name - - png_file = dot_file.replace(".dot", ".png") - try: - subprocess.run(["dot", "-Tpng", dot_file, "-o", png_file, "-Gdpi=300"], check=True) - except FileNotFoundError: - print("Graphviz not found. Install Graphviz to generate PNGs.") + for func in self.code.functions: + if func.findex.value == index: + print(disasm.func(self.code, func)) + return + for native in self.code.natives: + if native.findex.value == index: + print(disasm.native_header(self.code, native)) return + print("Function not found.") - try: - if platform.system() == "Windows": - subprocess.run(["start", png_file], shell=True) - elif platform.system() == "Darwin": - subprocess.run(["open", png_file]) - else: - subprocess.run(["xdg-open", png_file]) - os.unlink(dot_file) - except: - print(f"Control flow graph saved to {png_file}. Use your favourite image viewer to open it.") + def cfg(self, args: List[str]) -> None: + """Renders a control flow graph for a given findex and attempts to open it in the default image viewer. `cfg `""" + if len(args) == 0: + print("Usage: cfg ") return - print("Function not found.") - - -def cmd_ir(args: List[str], code: Bytecode) -> None: - if not args: - print("Usage: ir ") - try: - index = int(args[0]) - except ValueError: - print("Invalid index.") - return - for func in code.functions: - if func.findex.value == index: - ir = decomp.IRFunction(code, func) - ir.print() + try: + index = int(args[0]) + except ValueError: + print("Invalid index.") return - print("Function not found.") - - -def cmd_patch(args: List[str], code: Bytecode) -> None: - if not args: - print("Usage: patch ") - return - try: - index = int(args[0]) - except ValueError: - print("Invalid index.") - return - try: - func = code.fn(index) - except ValueError: + for func in self.code.functions: + if func.findex.value == index: + cfg = decomp.CFGraph(func) + print("Building control flow graph...") + cfg.build() + print("DOT:") + dot = cfg.graph(self.code) + print(dot) + print("Attempting to render graph...") + with tempfile.NamedTemporaryFile(suffix=".dot", delete=False) as f: + f.write(dot.encode()) + dot_file = f.name + + png_file = dot_file.replace(".dot", ".png") + try: + subprocess.run(["dot", "-Tpng", dot_file, "-o", png_file, "-Gdpi=300"], check=True) + except FileNotFoundError: + print("Graphviz not found. Install Graphviz to generate PNGs.") + return + + try: + if platform.system() == "Windows": + subprocess.run(["start", png_file], shell=True) + elif platform.system() == "Darwin": + subprocess.run(["open", png_file]) + else: + subprocess.run(["xdg-open", png_file]) + os.unlink(dot_file) + except: + print(f"Control flow graph saved to {png_file}. Use your favourite image viewer to open it.") + return print("Function not found.") - return - if isinstance(func, Native): - print("Cannot patch native.") - return - content = f"""{disasm.func(code, func)} -###### Modify the opcodes below this line. Any edits above this line will be ignored, and removing this line will cause patching to fail. ##### -{disasm.to_asm(func.ops)}""" - with tempfile.NamedTemporaryFile(suffix=".hlasm", mode="w", encoding="utf-8", delete=False) as f: - f.write(content) - file = f.name - try: - import tkinter as tk - from tkinter import scrolledtext - - def save_and_exit() -> None: - with open(file, "w", encoding="utf-8") as f: - f.write(text.get("1.0", tk.END)) - root.destroy() - - root = tk.Tk() - root.title(f"Editing function f@{index}") - text = scrolledtext.ScrolledText(root, width=200, height=50) - text.pack() - text.insert("1.0", content) - - button = tk.Button(root, text="Save and Exit", command=save_and_exit) - button.pack() - - root.mainloop() - except ImportError: - if os.name == "nt": - os.system(f'notepad "{file}"') - elif os.name == "posix": - os.system(f'nano "{file}"') - else: - print("No suitable editor found") - os.unlink(file) + def ir(self, args: List[str]) -> None: + """Prints the IR of a function in object-notation. `ir `""" + if len(args) == 0: + print("Usage: ir ") + try: + index = int(args[0]) + except ValueError: + print("Invalid index.") return - try: - with open(file, "r", encoding="utf-8") as f2: # why mypy, why??? - modified = f2.read() + for func in self.code.functions: + if func.findex.value == index: + ir = decomp.IRFunction(self.code, func) + ir.print() + return + print("Function not found.") - lines = modified.split("\n") - sep_idx = next(i for i, line in enumerate(lines) if "######" in line) - new_asm = "\n".join(lines[sep_idx + 1 :]) - new_ops = disasm.from_asm(new_asm) + def patch(self, args: List[str]) -> None: + """Patches a function's raw opcodes. `patch `""" + if len(args) == 0: + print("Usage: patch ") + return + try: + index = int(args[0]) + except ValueError: + print("Invalid index.") + return + try: + func = self.code.fn(index) + except ValueError: + print("Function not found.") + return + if isinstance(func, Native): + print("Cannot patch native.") + return + content = f"""{disasm.func(self.code, func)} - func.ops = new_ops - print(f"Function f@{index} updated successfully") +###### Modify the opcodes below this line. Any edits above this line will be ignored, and removing this line will cause patching to fail. ##### +{disasm.to_asm(func.ops)}""" + with tempfile.NamedTemporaryFile(suffix=".hlasm", mode="w", encoding="utf-8", delete=False) as f: + f.write(content) + file = f.name + try: + import tkinter as tk + from tkinter import scrolledtext + + def save_and_exit() -> None: + with open(file, "w", encoding="utf-8") as f: + f.write(text.get("1.0", tk.END)) + root.destroy() + + root = tk.Tk() + root.title(f"Editing function f@{index}") + text = scrolledtext.ScrolledText(root, width=200, height=50) + text.pack() + text.insert("1.0", content) + + button = tk.Button(root, text="Save and Exit", command=save_and_exit) + button.pack() + + root.mainloop() + except ImportError: + if os.name == "nt": + os.system(f'notepad "{file}"') + elif os.name == "posix": + os.system(f'nano "{file}"') + else: + print("No suitable editor found") + os.unlink(file) + return + try: + with open(file, "r", encoding="utf-8") as f2: # why mypy, why??? + modified = f2.read() - except Exception as e: - print(f"Failed to patch function: {e}") - finally: - os.unlink(file) + lines = modified.split("\n") + sep_idx = next(i for i, line in enumerate(lines) if "######" in line) + new_asm = "\n".join(lines[sep_idx + 1 :]) + new_ops = disasm.from_asm(new_asm) + func.ops = new_ops + print(f"Function f@{index} updated successfully") -def cmd_save(args: List[str], code: Bytecode) -> None: - if not args: - print("Usage: save ") - return - print("Serialising...") - ser = code.serialise() - print("Saving...") - with open(args[0], "wb") as f: - f.write(ser) - print("Done!") - - -# typing is ignored for lambdas because webbrowser.open returns a bool instead of None -COMMANDS: Dict[str, Tuple[Callable[[List[str], Bytecode], None], str]] = { - "exit": (lambda _, __: sys.exit(), "Exit the program"), - "help": (cmd_help, "Show this help message"), - "wiki": ( - lambda _, __: webbrowser.open("https://github.com/Gui-Yom/hlbc/wiki/Bytecode-file-format"), # type: ignore - "Open the HLBC wiki in your default browser", - ), - "opcodes": ( - lambda _, __: webbrowser.open("https://github.com/Gui-Yom/hlbc/blob/master/crates/hlbc/src/opcodes.rs"), # type: ignore - "Open the HLBC source to opcodes.rs in your default browser", - ), - "funcs": ( - cmd_funcs, - "List all functions in the bytecode - pass 'std' to not exclude stdlib", - ), - "entry": (cmd_entry, "Show the entrypoint of the bytecode"), - "fn": (cmd_fn, "Show information about a function"), - # "decomp": (cmd_decomp, "Decompile a function"), - "cfg": (cmd_cfg, "Graph the control flow graph of a function"), - "patch": (cmd_patch, "Patch a function's raw opcodes"), - "save": (cmd_save, "Save the modified bytecode to a given path"), - "ir": (cmd_ir, "Display the IR of a function in object-notation"), -} -""" -List of CLI commands. -""" + except Exception as e: + print(f"Failed to patch function: {e}") + finally: + os.unlink(file) + def save(self, args: List[str]) -> None: + """Saves the modified bytecode to a given path. `save `""" + if len(args) == 0: + print("Usage: save ") + return + print("Serialising...") + ser = self.code.serialise() + print("Saving...") + with open(args[0], "wb") as f: + f.write(ser) + print("Done!") + + def _get_commands(self) -> Dict[str, callable]: + """Get all command methods using reflection""" + return {name: func for name, func in inspect.getmembers(self, predicate=inspect.ismethod) + if not name.startswith('_')} def handle_cmd(code: Bytecode, is_hlbc: bool, cmd: str) -> None: - """ - Handles a command. - """ + """Handles a command.""" cmd_list: List[str] = cmd.split(" ") - if not is_hlbc: - for command in COMMANDS: - if cmd_list[0] == command: - COMMANDS[command][0](cmd_list[1:], code) - return - else: + if not cmd_list[0]: + return + + if is_hlbc: raise NotImplementedError("HLBC compatibility mode is not yet implemented.") - print("Unknown command.") - + + commands = Commands(code) + available_commands = commands._get_commands() + + if cmd_list[0] in available_commands: + available_commands[cmd_list[0]](cmd_list[1:]) + else: + print("Unknown command.") def main() -> None: """