Skip to content

Commit

Permalink
Solve calldata per method
Browse files Browse the repository at this point in the history
  • Loading branch information
arcz committed Apr 18, 2023
1 parent 36c4058 commit e65dcb2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 45 deletions.
96 changes: 54 additions & 42 deletions lib/Echidna/SymExec.hs
Original file line number Diff line number Diff line change
@@ -1,56 +1,68 @@
{-# LANGUAGE DataKinds #-}

module Echidna.SymExec where

import Control.Monad (forM)
import Control.Monad.State.Strict (evalStateT)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BS
import Data.Functor ((<&>))
import Data.Map qualified as Map
import Data.Maybe (catMaybes, fromMaybe)
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Text qualified as T

import EVM (StorageModel(..))
import EVM.ABI
import EVM.Expr (simplify)
import EVM.Fetch qualified as Fetch
import EVM.Solidity (SolcContract(..))
import EVM.SMT
import EVM.Solidity (SolcContract(..), Method(..))
import EVM.Solvers (withSolvers, Solver(Z3), CheckSatResult(Sat))
import EVM.SymExec (interpret, runExpr, abstractVM, mkCalldata, produceModels)
import EVM.SymExec (interpret, runExpr, abstractVM, mkCalldata, produceModels, Sig(Sig))
import EVM.Types
import EVM.SMT

import Echidna.Types.Tx

exploreContract :: Addr -> SolcContract -> IO [Tx]
exploreContract dst contract = do
let
calldata = mkCalldata Nothing [] -- fully abstract calldata, solution in buffers.txdata
-- Alternatively, more concrete calldata with a specific function,
-- not sure what is better yet
-- calldata = mkCalldata (Just (Sig method.methodSignature (snd <$> method.inputs))) []
vmSym = abstractVM calldata contract.runtimeCode Nothing SymbolicS
maxIter = Just 10
askSmtIters = Just 5
rpcInfo = Nothing
timeout = Nothing -- Just 1000 -- is it seconds?

withSolvers Z3 2 timeout $ \solvers -> do
exprInter <- evalStateT (interpret (Fetch.oracle solvers rpcInfo) maxIter askSmtIters runExpr) vmSym
models <- produceModels solvers (simplify exprInter)

txs <- forM models $ \(_end, result) ->
case result of
Sat cex ->
case Map.lookup (AbstractBuf "txdata") cex.buffers of
Just (Flat cd) -> do
let value = fromMaybe 0 $ Map.lookup (CallValue 0) cex.txContext
pure $ Just $
Tx { call = SolCalldata cd
, src = 0
, dst = dst
, gasprice = 0
, gas = maxGasPerBlock
, value = value
, delay = (0, 0)
}
Just (Comp _) -> pure Nothing -- TODO: compressed, implement me
Nothing -> pure Nothing

_ -> pure Nothing

pure $ catMaybes txs
let methods = Map.elems contract.abiMap
timeout = Just 30 -- seconds

res <- withSolvers Z3 2 timeout $ \solvers -> do
forM methods $ \method -> do
let
calldata = mkCalldata (Just (Sig method.methodSignature (snd <$> method.inputs))) []
vmSym = abstractVM calldata contract.runtimeCode Nothing AbstractStore
maxIter = Just 10
askSmtIters = Just 5
rpcInfo = Nothing

exprInter <- interpret (Fetch.oracle solvers rpcInfo) maxIter askSmtIters vmSym runExpr
models <- produceModels solvers (simplify exprInter)
pure $ mapMaybe (modelToTx dst method) models

pure $ mconcat res

modelToTx :: Addr -> Method -> (Expr 'End, CheckSatResult) -> Maybe Tx
modelToTx dst method (_end, result) =
case result of
Sat cex ->
let
args = (zip [1..] method.inputs) <&> \(i::Int, (_argName, argType)) ->
case Map.lookup (Var ("arg" <> T.pack (show i))) cex.vars of
Just w ->
decodeAbiValue argType (BS.fromStrict (word256Bytes w))
Nothing -> -- put a placeholder
decodeAbiValue argType (BS.repeat 0)

value = fromMaybe 0 $ Map.lookup (CallValue 0) cex.txContext

in Just Tx
{ call = SolCall (method.name, args)
, src = 0
, dst = dst
, gasprice = 0
, gas = maxGasPerBlock
, value = value
, delay = (0, 0)
}

_ -> Nothing
6 changes: 3 additions & 3 deletions tests/solidity/symbolic/sym.sol
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
contract VulnerableContract {
function func_one(int128 x) public payable {
function func_one(int256 x) public pure {
if (x / 4 == -20) {
assert(false); // BUG
}
}

function func_two(int128 x) public pure {
if ((x >> 30) / 7 == 2) {
function func_two(int128 x) public payable {
if ((msg.value >> 30) / 7 == 2) {
assert(false); // BUG
}
}
Expand Down

0 comments on commit e65dcb2

Please sign in to comment.