diff --git a/data/Switch.hx b/data/Switch.hx new file mode 100644 index 0000000..68fca56 --- /dev/null +++ b/data/Switch.hx @@ -0,0 +1,26 @@ +class Switch { + static function main() { + var a = 3; + var b = switch (a) { + case 0: a * 2; + case 3: a - 1; + default: a << 2; + } + switch (b) { + case 0: b += 1; + case 1: b += 1; + case 2: b += 1; + //case 4: b += 1; + } + /* + switch (a) { + case -1: b += 1; + case -5: b += 1; + } + var c = "hello"; + switch (c) { + case "hello": a += 1; + case "world": a += 1; + }*/ + } +} \ No newline at end of file diff --git a/src/decompiler/ast.rs b/src/decompiler/ast.rs index ddeb8a9..eba1f55 100644 --- a/src/decompiler/ast.rs +++ b/src/decompiler/ast.rs @@ -258,12 +258,10 @@ pub enum Statement { variable: Expr, assign: Expr, }, - // Call a void function (no assignment) - Call(Call), - /// Return an expression - Return(Expr), - /// Return nothing / early return - ReturnVoid, + /// Expression statement + ExprStatement(Expr), + /// Return an expression or nothing (void) + Return(Option), /// If statement If { cond: Expr, @@ -273,6 +271,11 @@ pub enum Statement { Else { stmts: Vec, }, + Switch { + arg: Expr, + default: Vec, + cases: Vec<(Expr, Vec)>, + }, /// While statement While { cond: Expr, @@ -282,3 +285,8 @@ pub enum Statement { Continue, Throw(Expr), } + +/// Create an expression statement +pub fn stmt(e: Expr) -> Statement { + Statement::ExprStatement(e) +} diff --git a/src/decompiler/fmt.rs b/src/decompiler/fmt.rs index 781d217..bc1e5d9 100644 --- a/src/decompiler/fmt.rs +++ b/src/decompiler/fmt.rs @@ -2,7 +2,7 @@ use std::fmt; use std::fmt::{Display, Formatter}; use crate::decompiler::ast::{ - Call, Class, Constant, ConstructorCall, Expr, Method, Operation, Statement, + Class, Constant, ConstructorCall, Expr, Method, Operation, Statement, }; use crate::types::{Function, RefField, Type}; use crate::Bytecode; @@ -245,13 +245,12 @@ impl Statement { } => { if *declaration { "var " } else { "" }{disp!(variable)}" = "{disp!(assign)}";" } - Statement::Call(Call { fun, args }) => { - {disp!(fun)}"("{fmtools::join(", ", args.iter().map(|e| disp!(e)))}");" + Statement::ExprStatement(expr) => { + {disp!(expr)}";" } Statement::Return(expr) => { - "return "{disp!(expr)}";" + "return" if let Some(e) = expr { " "{disp!(e)} } ";" } - Statement::ReturnVoid => "return;", Statement::If { cond, stmts } => { "if ("{disp!(cond)}") {\n" let indent2 = indent.inc_nesting(); @@ -268,6 +267,24 @@ impl Statement { } {indent}"}" } + Statement::Switch {arg, default, cases} => { + "switch ("{disp!(arg)}") {\n" + let indent2 = indent.inc_nesting(); + let indent3 = indent2.inc_nesting(); + if !default.is_empty() { + {indent2}"default:\n" + for stmt in default { + {indent3}{stmt.display(&indent3, code, f)}"\n" + } + } + for (pattern, stmts) in cases { + {indent2}"case "{disp!(pattern)}":\n" + for stmt in stmts { + {indent3}{stmt.display(&indent3, code, f)}"\n" + } + } + {indent}"}" + } Statement::While { cond, stmts } => { "while ("{disp!(cond)}") {\n" let indent2 = indent.inc_nesting(); diff --git a/src/decompiler/mod.rs b/src/decompiler/mod.rs index 37425d0..d581d31 100644 --- a/src/decompiler/mod.rs +++ b/src/decompiler/mod.rs @@ -92,28 +92,85 @@ pub fn decompile_class(code: &Bytecode, obj: &TypeObj) -> Class { } /// Helper to process a stack of scopes (branches, loops) -pub struct Scopes { +struct Scopes { // A linked list would be appreciable i think /// There is always at least one scope, the root scope scopes: Vec, } impl Scopes { - pub fn new() -> Self { + fn new() -> Self { Self { - scopes: vec![Scope { - ty: ScopeType::RootScope, - stmts: Vec::new(), - }], + scopes: vec![Scope::RootScope(Vec::new())], } } - pub fn pop_last_loop(&mut self) -> Option<(LoopScope, Vec)> { + fn push_if(&mut self, len: i32, cond: Expr) { + self.scopes.push(Scope::If { + len, + cond, + stmts: Vec::new(), + }) + } + + fn push_else(&mut self, len: i32) { + self.scopes.push(Scope::Else { + len, + stmts: Vec::new(), + }) + } + + fn push_loop(&mut self, start: usize) { + self.scopes.push(Scope::Loop(LoopScope { + start, + cond: None, + stmts: Vec::new(), + })) + } + + fn push_switch(&mut self, len: i32, arg: Expr, offsets: Vec) { + self.scopes.push(Scope::Switch(SwitchScope { + len, + arg, + offsets, + default: Vec::new(), + cases: Vec::new(), + })) + } + + fn push_switch_case(&mut self, cst: usize) { + let last = self.pop(); + match last { + Scope::Switch(switch) => { + self.scopes.push(Scope::Switch(switch)); + self.scopes.push(Scope::SwitchCase(SwitchCase { + pattern: cst_int(cst as i32), + stmts: Vec::new(), + })); + } + Scope::SwitchCase(case) => { + if let Scope::Switch(switch) = self.last_mut() { + switch.cases.push(case); + } else { + panic!("push switch case without switch ?\n{:#?}", self.scopes); + } + self.scopes.push(Scope::SwitchCase(SwitchCase { + pattern: cst_int(cst as i32), + stmts: Vec::new(), + })); + } + _ => { + self.scopes.push(last); + } + } + } + + fn pop_last_loop(&mut self) -> Option { for idx in (0..self.depth()).rev() { let scope = self.scopes.remove(idx); - match scope.ty { - ScopeType::Loop(l) => { - return Some((l, scope.stmts)); + match scope { + Scope::Loop(l) => { + return Some(l); } _ => { self.scopes.insert(idx, scope); @@ -123,10 +180,10 @@ impl Scopes { None } - pub fn last_loop(&self) -> Option<&LoopScope> { + fn last_loop(&self) -> Option<&LoopScope> { for idx in (0..self.depth()).rev() { - match &self.scopes[idx].ty { - ScopeType::Loop(l) => { + match &self.scopes[idx] { + Scope::Loop(l) => { return Some(l); } _ => {} @@ -135,112 +192,174 @@ impl Scopes { None } - pub fn last_mut(&mut self) -> &mut Scope { + fn last(&self) -> &Scope { + self.scopes.last().unwrap() + } + + fn last_mut(&mut self) -> &mut Scope { self.scopes.last_mut().unwrap() } - pub fn push_scope(&mut self, scope: Scope) { + fn push_scope(&mut self, scope: Scope) { self.scopes.push(scope); } - pub fn depth(&self) -> usize { + fn pop(&mut self) -> Scope { + self.scopes.pop().unwrap() + } + + fn depth(&self) -> usize { self.scopes.len() } - pub fn push_stmt(&mut self, mut stmt: Option) { + fn push_stmt(&mut self, mut stmt: Option) { // Start to iterate from the end ('for' because we need the index) for idx in (0..self.depth()).rev() { - let mut scope = self.scopes.remove(idx); - - if let Some(stmt) = stmt.take() { - scope.stmts.push(stmt); - } + let scope = self.scopes.remove(idx); // We only handle branches we know the length of // We can't know the end of a loop scope before seeing the jump back - match scope.ty { - ScopeType::Branch { mut len, cond } => { + match scope { + Scope::If { + mut len, + cond, + mut stmts, + } => { + if let Some(stmt) = stmt.take() { + stmts.push(stmt); + } + // Decrease scope len len -= 1; if len <= 0 { //println!("Decrease nesting {parent:?}"); - stmt = Some(Statement::If { - cond, - stmts: scope.stmts, - }); + stmt = Some(Statement::If { cond, stmts }); } else { // Scope continues - self.scopes.insert( - idx, - Scope { - ty: ScopeType::Branch { len, cond }, - stmts: scope.stmts, - }, - ); + self.scopes.insert(idx, Scope::If { len, cond, stmts }); } } - ScopeType::Else { mut len } => { + Scope::Else { mut len, mut stmts } => { + if let Some(stmt) = stmt.take() { + stmts.push(stmt); + } + // Decrease scope len len -= 1; if len <= 0 { //println!("Decrease nesting {parent:?}"); - stmt = Some(Statement::Else { stmts: scope.stmts }); + stmt = Some(Statement::Else { stmts }); } else { // Scope continues - self.scopes.insert( - idx, - Scope { - ty: ScopeType::Else { len }, - stmts: scope.stmts, - }, - ); + self.scopes.insert(idx, Scope::Else { len, stmts }); } } - _ => { - self.scopes.insert(idx, scope); + Scope::Switch(mut switch) => { + if let Some(stmt) = stmt.take() { + switch.default.push(stmt); + } + + switch.len -= 1; + if switch.len <= 0 { + stmt = Some(Statement::Switch { + arg: switch.arg, + default: switch.default, + cases: switch + .cases + .into_iter() + .map(|case| (case.pattern, case.stmts)) + .collect(), + }); + } else { + self.scopes.insert(idx, Scope::Switch(switch)); + } + } + Scope::SwitchCase(mut case) => { + if let Some(stmt) = stmt.take() { + case.stmts.push(stmt); + } + if let Scope::Switch(switch) = self.last_mut() { + if switch.len <= 1 { + switch.cases.push(case); + } else { + self.scopes.insert(idx, Scope::SwitchCase(case)); + } + } + } + Scope::RootScope(mut stmts) => { + if let Some(stmt) = stmt.take() { + stmts.push(stmt); + } + self.scopes.insert(idx, Scope::RootScope(stmts)); + } + Scope::Loop(mut loop_) => { + if let Some(stmt) = stmt.take() { + loop_.stmts.push(stmt); + } + self.scopes.insert(idx, Scope::Loop(loop_)); } } } if let Some(stmt) = stmt.take() { - self.last_mut().stmts.push(stmt); + match self.last_mut() { + Scope::RootScope(stmts) + | Scope::If { stmts, .. } + | Scope::Else { stmts, .. } + | Scope::Switch(SwitchScope { default: stmts, .. }) + | Scope::SwitchCase(SwitchCase { stmts, .. }) + | Scope::Loop(LoopScope { stmts, .. }) => { + stmts.push(stmt); + } + } } } - pub fn statements(mut self) -> Vec { - self.scopes.pop().unwrap().stmts - } -} - -pub struct Scope { - pub ty: ScopeType, - pub stmts: Vec, -} - -impl Scope { - pub fn new(ty: ScopeType) -> Self { - Self { - ty, - stmts: Vec::new(), + fn statements(mut self) -> Vec { + if let Scope::RootScope(stmts) = self.pop() { + stmts + } else { + unreachable!("mmmmhhh kinda sus:\n{:#?}\n", self.scopes); } } } -pub enum ScopeType { - RootScope, - Branch { len: i32, cond: Expr }, - Else { len: i32 }, +#[derive(Debug)] +enum Scope { + RootScope(Vec), + If { + len: i32, + cond: Expr, + stmts: Vec, + }, + Else { + len: i32, + stmts: Vec, + }, Loop(LoopScope), + Switch(SwitchScope), + SwitchCase(SwitchCase), } -pub struct LoopScope { - pub cond: Option, - pub start: usize, +#[derive(Debug)] +struct LoopScope { + cond: Option, + start: usize, + stmts: Vec, } -impl LoopScope { - pub fn new(start: usize) -> Self { - Self { cond: None, start } - } +#[derive(Debug)] +struct SwitchScope { + len: i32, + offsets: Vec, + arg: Expr, + default: Vec, + cases: Vec, +} + +#[derive(Debug)] +struct SwitchCase { + pattern: Expr, + stmts: Vec, } enum ExprCtx { @@ -265,6 +384,7 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { // Expression values for each registers let mut reg_state = HashMap::with_capacity(f.regs.len()); // For parsing statements made of multiple instructions like constructor calls and anonymous structures + // TODO move this to another pass on the generated ast let mut expr_ctx = Vec::new(); // Variable names we already declared let mut seen = HashSet::new(); @@ -314,6 +434,16 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { }; } + macro_rules! push_expr_stmt { + ($i:ident, $dst:ident, $e:expr) => { + if f.var_name(code, $i).is_some() { + push_stmt!(stmt($e)) + } else { + reg_state.insert($dst, $e); + } + }; + } + let missing_expr = || Expr::Unknown("missing expr".to_owned()); // Get the expr for a register @@ -345,22 +475,22 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { match $fun.resolve(code) { FunPtr::Fun(func) => { let call = if func.is_method() { - Call::new(Expr::Field(Box::new(expr!($arg0)), func.name.clone().unwrap().resolve(&code.strings).to_owned()), make_args!($($args),*)) + call(Expr::Field(Box::new(expr!($arg0)), func.name.clone().unwrap().resolve(&code.strings).to_owned()), make_args!($($args),*)) } else { - Call::new_fun($fun, make_args!($arg0 $(, $args)*)) + call_fun($fun, make_args!($arg0 $(, $args)*)) }; if func.ty(code).ret.is_void() { - push_stmt!(Statement::Call(call)); + push_stmt!(stmt(call)); } else { - push_expr!($i, $dst, Expr::Call(Box::new(call))); + push_expr!($i, $dst, call); } } FunPtr::Native(n) => { - let call = Call::new_fun($fun, make_args!($arg0 $(, $args)*)); + let call = call_fun($fun, make_args!($arg0 $(, $args)*)); if n.ty(code).ret.is_void() { - push_stmt!(Statement::Call(call)); + push_stmt!(stmt(call)); } else { - push_expr!($i, $dst, Expr::Call(Box::new(call))); + push_expr!($i, $dst, call); } } } @@ -374,27 +504,18 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { if $offset > 0 { // It's a loop if matches!(f.ops[$i + $offset as usize], Opcode::JAlways { offset } if offset < 0) { - if let ScopeType::Loop(loop_) = &mut scopes.last_mut().ty { + if let Scope::Loop(loop_) = scopes.last_mut() { if loop_.cond.is_none() { loop_.cond = Some($cond); } else { - scopes.push_scope(Scope::new(ScopeType::Branch { - len: $offset + 1, - cond: $cond, - })); + scopes.push_if($offset + 1, $cond); } } else { - scopes.push_scope(Scope::new(ScopeType::Branch { - len: $offset + 1, - cond: $cond, - })); + scopes.push_if($offset + 1, $cond); } } else { // It's an if - scopes.push_scope(Scope::new(ScopeType::Branch { - len: $offset + 1, - cond: $cond, - })); + scopes.push_if($offset + 1, $cond); } } } @@ -459,17 +580,17 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { push_expr!(i, dst, not(expr!(src))); } &Opcode::Incr { dst } => { - push_expr!(i, dst, incr(expr!(dst))); + push_expr_stmt!(i, dst, incr(expr!(dst))); } &Opcode::Decr { dst } => { - push_expr!(i, dst, decr(expr!(dst))); + push_expr_stmt!(i, dst, decr(expr!(dst))); } //endregion //region CALLS &Opcode::Call0 { dst, fun } => { if fun.ty(code).ret.is_void() { - push_stmt!(Statement::Call(Call::new_fun(fun, Vec::new()))); + push_stmt!(stmt(call_fun(fun, Vec::new()))); } else { push_expr!(i, dst, call_fun(fun, Vec::new())); } @@ -518,7 +639,7 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { } } else { if fun.ty(code).ret.is_void() { - push_stmt!(Statement::Call(Call::new_fun( + push_stmt!(stmt(call_fun( *fun, args.iter().map(|x| expr!(x)).collect::>(), ))); @@ -533,7 +654,7 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { } Opcode::CallMethod { dst, field, args } => { let method = f.regtype(args[0]).method(field.0, code).unwrap(); - let call = Call::new( + let call = call( Expr::Field( Box::new(expr!(args[0])), method.name.resolve(&code.strings).to_owned(), @@ -546,14 +667,14 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { .map(|fun| fun.ty(code).ret.is_void()) .unwrap_or(false) { - push_stmt!(Statement::Call(call)); + push_stmt!(stmt(call)); } else { - push_expr!(i, *dst, Expr::Call(Box::new(call))); + push_expr!(i, *dst, call); } } Opcode::CallThis { dst, field, args } => { let method = f.regs[0].method(field.0, code).unwrap(); - let call = Call::new( + let call = call( Expr::Field( Box::new(cst_this()), method.name.resolve(&code.strings).to_owned(), @@ -566,13 +687,13 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { .map(|fun| fun.ty(code).ret.is_void()) .unwrap_or(false) { - push_stmt!(Statement::Call(call)); + push_stmt!(stmt(call)); } else { - push_expr!(i, *dst, Expr::Call(Box::new(call))); + push_expr!(i, *dst, call); } } Opcode::CallClosure { dst, fun, args } => { - let call = Call::new( + let call = call( expr!(*fun), args.iter().map(|x| expr!(x)).collect::>(), ); @@ -581,9 +702,9 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { .map(|ty| ty.ret.is_void()) .unwrap_or(false) { - push_stmt!(Statement::Call(call)); + push_stmt!(stmt(call)); } else { - push_expr!(i, *dst, Expr::Call(Box::new(call))); + push_expr!(i, *dst, call); } } //endregion @@ -761,41 +882,78 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { // It's definitely a continue; statement push_stmt!(Statement::Continue); } else { - let (loop_, stmts) = scopes.pop_last_loop().unwrap(); + let loop_ = scopes.pop_last_loop().unwrap(); // It's the last jump backward of the loop, we generate the loop statement if let Some(cond) = loop_.cond { - push_stmt!(Statement::While { cond, stmts }); + push_stmt!(Statement::While { cond, stmts: loop_.stmts}); } else { push_stmt!(Statement::While { cond: cst_bool(true), - stmts + stmts: loop_.stmts }); } } } else { - // Check the instruction just before the jump target - // If it's a jump backward of a loop - if matches!(f.ops[(i as i32 + offset) as usize], Opcode::JAlways {offset} if offset < 0) - { - // It's a break condition - push_stmt!(Statement::Break); - } else { - // It's the jump over of an else clause - scopes.push_scope(Scope::new(ScopeType::Else { len: offset + 1 })); + match &scopes.last() { + Scope::Switch(switch) => { + if let Some(pos) = switch.offsets.iter().position(|o| *o == i) { + scopes.push_switch_case(pos); + } else { + panic!("no offset"); + } + } + Scope::SwitchCase(case) => { + let len = scopes.scopes.len(); + let penult = &mut scopes.scopes[len - 2]; + if let Scope::Switch(switch) = penult { + if let Some(pos) = switch.offsets.iter().position(|o| *o == i) { + scopes.push_switch_case(pos); + } else { + panic!("no offset"); + } + } else { + panic!("wtf"); + } + } + Scope::Loop(_) => { + // Check the instruction just before the jump target + // If it's a jump backward of a loop + if matches!(f.ops[(i as i32 + offset) as usize], Opcode::JAlways {offset} if offset < 0) + { + // It's a break condition + push_stmt!(Statement::Break); + } + } + Scope::If { .. } => { + // It's the jump over of an else clause + scopes.push_else(offset + 1); + } + _ => { + println!("JAlways > 0 with no matching scope ?"); + } } } } - &Opcode::Label => scopes.push_scope(Scope::new(ScopeType::Loop(LoopScope::new(i)))), + Opcode::Switch { reg, offsets, end } => { + // Convert to absolute positions + scopes.push_switch( + *end + 1, + expr!(reg), + offsets.iter().map(|o| i + *o as usize).collect(), + ); + // The default switch case is implicit + } + &Opcode::Label => scopes.push_loop(i), &Opcode::Ret { ret } => { // Do not display return void; only in case of an early return if scopes.depth() > 1 { - statement = Some(if f.regtype(ret).is_void() { - Statement::ReturnVoid + push_stmt!(Statement::Return(if f.regtype(ret).is_void() { + None } else { - Statement::Return(expr!(ret)) - }); + Some(expr!(ret)) + })); } else if !f.regtype(ret).is_void() { - statement = Some(Statement::Return(expr!(ret))); + push_stmt!(Statement::Return(Some(expr!(ret)))); } } //endregion @@ -878,9 +1036,5 @@ pub fn decompile_function(code: &Bytecode, f: &Function) -> Vec { scopes.push_stmt(statement.take()); } - if scopes.depth() != 1 { - println!("Wait a minute, not all scopes have finished"); - } - scopes.statements() }