Skip to content

Commit

Permalink
Add new simplification rule for invariant loop parameters.
Browse files Browse the repository at this point in the history
This was suggested by Cosmin to address some of the code produced by
AD.

The idea is that for a loop of the form

loop p = x ...
  ...stms...
  in res

we construct and simplify the body

  let p = x
  ...stms...
  in res

and if that simplifies to 'x', then we conclude that the loop
parameter 'p' must be invariant to the loop and simply bind it (and
the loop result) to 'x'.

Complication: for multi-parameter loops, we must also check that
the *original* computation of 'res' does *only* depends on other
invariant loop parameters.

Currently we do this only for loops that have a constant as one of
their initial loop parameter values.

The main downside of this rule is that doing recursive simplification
is quite expensive.  Especially after sequentialisation, pretty much
every 'reduce' will have been turned into a loop that triggers this
rule (although the rule itself will fail in most cases, after doing
the simplification).  Therefore I'm a bit hesitant to enable it as is.
Sure, the Futhark compiler is slow and it was never meant to be fast,
but it is still quite easy for the compiler to become *uselessly slow*
if we are not careful.  E.g. on OptionPricing, this rule itself makes
compilation 10% slower (and does not actually optimise anything - this
is purely the cost of failing checks).
  • Loading branch information
athas committed Aug 3, 2023
1 parent ee56f96 commit 3127997
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 42 deletions.
19 changes: 16 additions & 3 deletions src/Futhark/Optimise/Simplify/Engine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ module Futhark.Optimise.Simplify.Engine
bindLParams,
simplifyBody,
ST.SymbolTable,
hoistStms,
blockIf,
blockMigrated,
enterLoop,
Expand Down Expand Up @@ -206,6 +205,18 @@ asksEngineEnv f = f <$> askEngineEnv
askVtable :: SimpleM rep (ST.SymbolTable (Wise rep))
askVtable = asksEngineEnv envVtable

mkSubSimplify :: SimplifiableRep rep => SimpleM rep (SubSimplify (Wise rep))
mkSubSimplify = do
(ops, env) <- ask
pure $ \body -> do
scope <- askScope
let env' = env {envVtable = ST.fromScope scope}
(x, _) <- modifyNameSource $ runSimpleM (f body) ops env'
pure x
where
f body =
simplifyBodyNoHoisting mempty (map (const mempty) (bodyResult body)) body

localVtable ::
(ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
SimpleM rep a ->
Expand Down Expand Up @@ -486,7 +497,8 @@ hoistStms rules block orig_stms final = do

process usageInStm stm stms usage x = do
vtable <- askVtable
res <- bottomUpSimplifyStm rules (vtable, usage) stm
ss <- mkSubSimplify
res <- bottomUpSimplifyStm ss rules (vtable, usage) stm
case res of
Nothing -- Nothing to optimise - see if hoistable.
| block vtable usage stm ->
Expand Down Expand Up @@ -517,7 +529,8 @@ hoistStms rules block orig_stms final = do
stms_h' <- nonrecSimplifyStm stms_h

vtable <- askVtable
simplified <- topDownSimplifyStm rules vtable stms_h'
ss <- mkSubSimplify
simplified <- topDownSimplifyStm ss rules vtable stms_h'

case simplified of
Just newstms -> do
Expand Down
88 changes: 55 additions & 33 deletions src/Futhark/Optimise/Simplify/Rule.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ module Futhark.Optimise.Simplify.Rule
RuleM,
cannotSimplify,
liftMaybe,
SubSimplify,
subSimplify,

-- * Rule definition
Rule (..),
Expand Down Expand Up @@ -53,19 +55,26 @@ module Futhark.Optimise.Simplify.Rule
)
where

import Control.Monad.Reader
import Control.Monad.State
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Builder
import Futhark.IR

-- | An action for recursively simplifying a body.
type SubSimplify rep = Body rep -> RuleM rep (Body rep)

newtype RuleEnv rep = RuleEnv {envSubSimplify :: SubSimplify rep}

-- | The monad in which simplification rules are evaluated.
newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource Maybe) a)
newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource (ReaderT (RuleEnv rep) Maybe)) a)
deriving
( Functor,
Applicative,
Monad,
MonadFreshNames,
MonadReader (RuleEnv rep),
HasScope rep,
LocalScope rep
)
Expand All @@ -84,19 +93,29 @@ instance (BuilderOps rep) => MonadBuilder (RuleM rep) where
simplify ::
Scope rep ->
VNameSource ->
RuleEnv rep ->
Rule rep ->
Maybe (Stms rep, VNameSource)
simplify _ _ Skip = Nothing
simplify scope src (Simplify (RuleM m)) =
runStateT (runBuilderT_ m scope) src
simplify _ _ _ Skip = Nothing
simplify scope src env (Simplify (RuleM m)) =
runReaderT (runStateT (runBuilderT_ m scope) src) env

-- | Abort the current attempt at simplification.
cannotSimplify :: RuleM rep a
cannotSimplify = RuleM $ lift $ lift Nothing
cannotSimplify = RuleM $ lift $ lift $ lift Nothing

liftMaybe :: Maybe a -> RuleM rep a
liftMaybe Nothing = cannotSimplify
liftMaybe (Just x) = pure x

-- | Recursively apply the simplifier on this body, using the current
-- rulebook. This can be quite costly, so think carefully before
-- doing this.
subSimplify :: SubSimplify rep
subSimplify body = do
s <- asks envSubSimplify
s body

-- | An efficient way of encoding whether a simplification rule should even be attempted.
data Rule rep
= -- | Give it a shot.
Expand Down Expand Up @@ -252,31 +271,6 @@ ruleBook topdowns bottomups =
forOp RuleGeneric {} = True
forOp _ = False

-- | @simplifyStm lookup stm@ performs simplification of the
-- binding @stm@. If simplification is possible, a replacement list
-- of bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
topDownSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
RuleBook rep ->
ST.SymbolTable rep ->
Stm rep ->
m (Maybe (Stms rep))
topDownSimplifyStm = applyRules . bookTopDownRules

-- | @simplifyStm uses stm@ performs simplification of the binding
-- @stm@. If simplification is possible, a replacement list of
-- bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
-- The first argument is the set of names used after this binding.
bottomUpSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
RuleBook rep ->
(ST.SymbolTable rep, UT.UsageTable) ->
Stm rep ->
m (Maybe (Stms rep))
bottomUpSimplifyStm = applyRules . bookBottomUpRules

rulesForStm :: Stm rep -> Rules rep a -> [SimplificationRule rep a]
rulesForStm stm = case stmExp stm of
BasicOp {} -> rulesBasicOp
Expand All @@ -299,19 +293,47 @@ applyRule _ _ _ =

applyRules ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
SubSimplify rep ->
Rules rep a ->
a ->
Stm rep ->
m (Maybe (Stms rep))
applyRules all_rules context stm = do
applyRules ss all_rules context stm = do
scope <- askScope

let env = RuleEnv ss
modifyNameSource $ \src ->
let applyRules' [] = Nothing
applyRules' (rule : rules) =
case simplify scope src (applyRule rule context stm) of
case simplify scope src env (applyRule rule context stm) of
Just x -> Just x
Nothing -> applyRules' rules
in case applyRules' $ rulesForStm stm all_rules of
Just (stms, src') -> (Just stms, src')
Nothing -> (Nothing, src)

-- | @simplifyStm lookup stm@ performs simplification of the
-- binding @stm@. If simplification is possible, a replacement list
-- of bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
topDownSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
SubSimplify rep ->
RuleBook rep ->
ST.SymbolTable rep ->
Stm rep ->
m (Maybe (Stms rep))
topDownSimplifyStm ss = applyRules ss . bookTopDownRules

-- | @simplifyStm uses stm@ performs simplification of the binding
-- @stm@. If simplification is possible, a replacement list of
-- bindings is returned, that bind at least the same names as the
-- original binding (and possibly more, for intermediate results).
-- The first argument is the set of names used after this binding.
bottomUpSimplifyStm ::
(MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
SubSimplify rep ->
RuleBook rep ->
(ST.SymbolTable rep, UT.UsageTable) ->
Stm rep ->
m (Maybe (Stms rep))
bottomUpSimplifyStm ss = applyRules ss . bookBottomUpRules
79 changes: 77 additions & 2 deletions src/Futhark/Optimise/Simplify/Rules/Loop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where

import Control.Monad
import Data.Bifunctor (second)
import Data.Bifunctor (first, second)
import Data.List (partition)
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
Expand Down Expand Up @@ -83,6 +84,79 @@ removeRedundantMergeVariables (_, used) pat aux (merge, form, body)
removeRedundantMergeVariables _ _ _ _ =
Skip

-- For a loop of the form
--
-- loop p = x ...
-- ...stms...
-- in res
--
-- we construct and simplify the body
--
-- let p = x
-- ...stms...
-- in res
--
-- and if that simplifies to 'x', then we conclude that the loop
-- parameter 'p' must be invariant to the loop and simply bind it (and
-- the loop result) to 'x'.
--
-- Complication: for multi-parameter loops, we must also check that
-- the *original* computation of 'res' does *only* depends on other
-- invariant loop parameters. See tests/loops/invariant1.fut for an
-- example.
simplifyInvariantParams :: BuilderOps rep => TopDownRuleLoop rep
simplifyInvariantParams _vtable pat aux (params, form, loopbody)
| consts <- filter constInit params,
not $ null consts = Simplify . auxing aux $
localScope (scopeOfFParams (map fst params) <> scopeOf form) $ do
loopbody_simpl <- subSimplify <=< buildBody_ $ do
mapM_ bindParam consts
bodyBind loopbody
let inv_pnames = determineInvariant $ bodyResult loopbody_simpl
invariant (_, (p, _), _) = paramName p `elem` inv_pnames
(inv, var) =
partition invariant $
zip3 (patElems pat) params (bodyResult loopbody)
(var_pes, var_params, var_res) = unzip3 var
when (null inv) cannotSimplify
mapM_ bindInv inv
loopbody' <- mkBodyM (bodyStms loopbody) var_res
letBind (Pat var_pes) $ Loop var_params form loopbody'
| otherwise = Skip
where
loopbody_deps = dataDependencies loopbody
resDep (Var v) = oneName v <> fromMaybe mempty (M.lookup v loopbody_deps)
resDep _ = mempty
res_deps = map (resDep . resSubExp) $ bodyResult loopbody

constInit (_, Constant {}) = True
constInit _ = False

bindParam (p, se) = letBindNames [paramName p] $ BasicOp $ SubExp se

bindInv (pe, (p, se), _) = do
letBindNames [patElemName pe] $ BasicOp $ SubExp se
letBindNames [paramName p] $ BasicOp $ SubExp se

resIsInvariant ((_, x), x') = x == resSubExp x'

depOnVar var (_, deps) = any (`nameIn` deps) var

noInvDepOnVar inv var
| (inv_var, inv') <- partition (depOnVar var) inv,
not $ null inv_var =
noInvDepOnVar inv' $ map fst inv_var <> var
| otherwise =
map fst inv

determineInvariant simpl_res =
let (inv, var) =
partition (resIsInvariant . fst) $
zip (zip (map (first paramName) params) simpl_res) res_deps
in noInvDepOnVar
(map (first (fst . fst)) inv)
(map (fst . fst . fst) var)

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BuilderOps rep => TopDownRuleLoop rep
Expand Down Expand Up @@ -290,7 +364,8 @@ topDownRules =
[ RuleLoop hoistLoopInvariantMergeVariables,
RuleLoop simplifyClosedFormLoop,
RuleLoop simplifyKnownIterationLoop,
RuleLoop simplifyLoopVariables
RuleLoop simplifyLoopVariables,
RuleLoop simplifyInvariantParams
]

bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
Expand Down
12 changes: 12 additions & 0 deletions tests/loops/invariant0.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Removal of invariant of invariant loop parameter (and eventually entire loop).
-- ==
-- structure { DoLoop 0 }

entry main [n] (bs: [n]bool) =
let res =
loop (x, y) = (0i32, false)
for i < n do
let y' = bs[i] && y
let x' = x + (i32.bool y')
in (x', y')
in res
13 changes: 13 additions & 0 deletions tests/loops/invariant1.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Not actually invariant if you look carefully!
-- ==
-- input { 0 } output { 0 false }
-- input { 4 } output { 3 true }

entry main (n: i32) =
let res =
loop (x, y) = (0i32, false)
for _i < n do
let x' = if y then x + 1 else x
let y' = y || true
in (x', y')
in res
6 changes: 2 additions & 4 deletions tests/fibloop.fut → tests/loops/invariant2.fut
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
-- Also not actually invariant.
-- ==
-- input { 0 } output { 1 }
-- input { 10 } output { 89 }


def fib(n: i32): i32 =
entry main (n: i32) =
let (x,_) = loop (x, y) = (1,1) for _i < n do (y, x+y)
in x

def main(n: i32): i32 = fib(n)

0 comments on commit 3127997

Please sign in to comment.