Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add errors and assertions to Acc computations #494

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions accelerate.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ library
Data.Array.Accelerate.Data.Maybe
Data.Array.Accelerate.Data.Monoid
Data.Array.Accelerate.Data.Ratio
Data.Array.Accelerate.Debug.Assert
Data.Array.Accelerate.Debug.Error
Data.Array.Accelerate.Debug.Trace
Data.Array.Accelerate.Unsafe

Expand Down
9 changes: 9 additions & 0 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where
-> acc aenv arrs2
-> PreOpenAcc acc aenv arrs2

Aerror :: ArraysR (arrs2)
-> Message arrs1
-> acc aenv arrs1
-> PreOpenAcc acc aenv arrs2

-- Array inlet. Triggers (possibly) asynchronous host->device transfer if
-- necessary.
--
Expand Down Expand Up @@ -768,6 +773,7 @@ instance HasArraysR acc => HasArraysR (PreOpenAcc acc) where
arraysR (Apair as bs) = TupRpair (arraysR as) (arraysR bs)
arraysR Anil = TupRunit
arraysR (Atrace _ _ bs) = arraysR bs
arraysR (Aerror r _ _) = r
arraysR (Apply aR _ _) = aR
arraysR (Aforeign r _ _ _) = r
arraysR (Acond _ a _) = arraysR a
Expand Down Expand Up @@ -992,6 +998,7 @@ rnfPreOpenAcc rnfA pacc =
Apair as bs -> rnfA as `seq` rnfA bs
Anil -> ()
Atrace msg as bs -> rnfM msg `seq` rnfA as `seq` rnfA bs
Aerror repr msg as -> rnfTupR rnfArrayR repr `seq` rnfM msg `seq` rnfA as
Apply repr afun acc -> rnfTupR rnfArrayR repr `seq` rnfAF afun `seq` rnfA acc
Aforeign repr asm afun a -> rnfTupR rnfArrayR repr `seq` rnf (strForeign asm) `seq` rnfAF afun `seq` rnfA a
Acond p a1 a2 -> rnfE p `seq` rnfA a1 `seq` rnfA a2
Expand Down Expand Up @@ -1200,6 +1207,7 @@ liftPreOpenAcc liftA pacc =
Apair as bs -> [|| Apair $$(liftA as) $$(liftA bs) ||]
Anil -> [|| Anil ||]
Atrace msg as bs -> [|| Atrace $$(liftMessage (arraysR as) msg) $$(liftA as) $$(liftA bs) ||]
Aerror repr msg as -> [|| Aerror $$(liftArraysR repr) $$(liftMessage (arraysR as) msg) $$(liftA as) ||]
Apply repr f a -> [|| Apply $$(liftArraysR repr) $$(liftAF f) $$(liftA a) ||]
Aforeign repr asm f a -> [|| Aforeign $$(liftArraysR repr) $$(liftForeign asm) $$(liftPreOpenAfun liftA f) $$(liftA a) ||]
Acond p t e -> [|| Acond $$(liftE p) $$(liftA t) $$(liftA e) ||]
Expand Down Expand Up @@ -1396,6 +1404,7 @@ showPreAccOp Alet{} = "Alet"
showPreAccOp (Avar (Var _ ix)) = "Avar a" ++ show (idxToInt ix)
showPreAccOp (Use aR a) = "Use " ++ showArrayShort 5 (showsElt (arrayRtype aR)) aR a
showPreAccOp Atrace{} = "Atrace"
showPreAccOp Aerror{} = "Aerror"
showPreAccOp Apply{} = "Apply"
showPreAccOp Aforeign{} = "Aforeign"
showPreAccOp Acond{} = "Acond"
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ encodePreOpenAcc options encodeAcc pacc =
Apair a1 a2 -> intHost $(hashQ "Apair") <> travA a1 <> travA a2
Anil -> intHost $(hashQ "Anil")
Atrace (Message _ _ msg) as bs -> intHost $(hashQ "Atrace") <> intHost (Hashable.hash msg) <> travA as <> travA bs
Aerror r (Message _ _ msg) as -> intHost $(hashQ "Aerror") <> encodeArraysType r <> intHost (Hashable.hash msg) <> travA as
Apply _ f a -> intHost $(hashQ "Apply") <> travAF f <> travA a
Aforeign _ _ f a -> intHost $(hashQ "Aforeign") <> travAF f <> travA a
Use repr a -> intHost $(hashQ "Use") <> encodeArrayType repr <> deep (encodeArray a)
Expand Down
122 changes: 122 additions & 0 deletions src/Data/Array/Accelerate/Debug/Assert.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module : Data.Array.Accelerate.Debug.Assert
-- Copyright : [2008..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <[email protected]>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
-- Functions for checking properties or invariants
-- of a program.
--
-- @since 1.4.0.0
--

module Data.Array.Accelerate.Debug.Assert (

-- * Assertions
-- $assertions
--
assert, Assertion,
expEqual, AssertEqual,
arraysEqual, AssertArraysEqual,

) where

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array as S
import Data.Array.Accelerate.Sugar.Elt
import qualified Data.Array.Accelerate.Representation.Array as R
import qualified Data.Array.Accelerate.Representation.Shape as R


-- $assertions
--
-- The 'assert' function verifies whether a predicate holds and will stop
-- the execution of the array computation if the assertion does not hold.
-- It will then also print the given error message to the console.
--
-- The predicate can be passed as a boolean expression ('Exp Bool'), but we
-- have specialized assertions for array equivalence ('arraysEqual') and
-- scalar equivalence ('expEqual').
--

-- Verifies whether the predicate holds, before the computation can continue
-- with the result of the last argument. If the assertion does not hold,
-- it will stop the array computation and print the error message.
--
assert :: forall a bs. (Assertion a bs, Arrays bs) => String -> a -> Acc bs -> Acc bs
assert text assertion result
= A.acond (assertionCondition assertion result) result
$ Acc
$ SmartAcc
$ Aerror (S.arraysR @bs)
(assertionMessage @a @bs $ "Assertion failed: " ++ text)
arg
where
Acc arg = assertionArg assertion result

class Arrays (AssertionArg a res) => Assertion a res where
type AssertionArg a res

assertionArg :: a -> Acc res -> Acc (AssertionArg a res)
assertionMessage :: String -> Message (ArraysR (AssertionArg a res))
assertionCondition :: a -> Acc res -> Exp Bool

instance Assertion (Exp Bool) res where
type AssertionArg (Exp Bool) res = ()

assertionArg _ _ = Acc (SmartAcc Anil)
assertionMessage = Message (\_ -> "") (Just [|| \_ -> "" ||])
assertionCondition = const

instance Assertion (Acc (Scalar Bool)) res where
type AssertionArg (Acc (Scalar Bool)) res = ()

assertionArg _ _ = Acc (SmartAcc Anil)
assertionMessage = Message (\_ -> "") (Just [|| \_ -> "" ||])
assertionCondition a _ = A.the a

instance (Assertion a (), Show res, Arrays res) => Assertion (Acc res -> a) res where
type AssertionArg (Acc res -> a) res = res

assertionArg _ res = res
assertionMessage = Message (\xs -> "\n" ++ show (toArr @res xs))
(Just [||(\xs -> "\n" ++ show (toArr @res xs)) ||])
assertionCondition f res = assertionCondition (f res) (Acc (SmartAcc Anil) :: Acc ())

data AssertEqual e = AssertEqual (Exp e) (Exp e)

expEqual :: Exp e -> Exp e -> AssertEqual e
expEqual = AssertEqual

instance (Elt e, A.Eq e, Show e) => Assertion (AssertEqual e) res where
type AssertionArg (AssertEqual e) res = Scalar (e, e)

assertionArg (AssertEqual a b) _ = A.unit (A.T2 a b)
assertionMessage = Message (\e -> let (a, b) = toElt @(e, e) (R.indexArray (R.ArrayR R.dim0 (eltR @(e, e))) e ()) in show a ++ " does not equal " ++ show b)
(Just [||(\e -> let (a, b) = toElt @(e, e) (R.indexArray (R.ArrayR R.dim0 (eltR @(e, e))) e ()) in show a ++ " does not equal " ++ show b) ||])
assertionCondition (AssertEqual a b) _ = a A.== b

data AssertArraysEqual as = AssertArraysEqual (Acc as) (Acc as)

arraysEqual :: Acc as -> Acc as -> AssertArraysEqual as
arraysEqual = AssertArraysEqual

instance (Show sh, Show e, A.Shape sh, Elt e, A.Eq sh, A.Eq e) => Assertion (AssertArraysEqual (S.Array sh e)) res where
type AssertionArg (AssertArraysEqual (S.Array sh e)) res = (S.Array sh e, S.Array sh e)

assertionArg (AssertArraysEqual xs ys) _ = A.T2 xs ys
assertionMessage = Message (\(((), xs), ys) -> "\n" ++ show (toArr @(S.Array sh e) xs) ++ "\ndoes not equal\n" ++ show (toArr @(S.Array sh e) ys))
(Just [||(\(((), xs), ys) -> "\n" ++ show (toArr @(S.Array sh e) xs) ++ "\ndoes not equal\n" ++ show (toArr @(S.Array sh e) ys)) ||])
assertionCondition (AssertArraysEqual xs ys) _ = (A.shape xs A.== A.shape ys) A.&& A.the (A.all id $ A.reshape (A.I1 $ A.size xs) $ A.zipWith (A.==) xs ys)
81 changes: 81 additions & 0 deletions src/Data/Array/Accelerate/Debug/Error.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module : Data.Array.Accelerate.Debug.Error
-- Copyright : [2008..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <[email protected]>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
-- Functions for checking properties or invariants
-- of a program.
--
-- @since 1.4.0.0
--

module Data.Array.Accelerate.Debug.Error (

-- * Throwing errors
-- $errors
--
aerror, aerrorArray, aerrorExp

) where

import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array as S
import Data.Array.Accelerate.Sugar.Elt
import qualified Data.Array.Accelerate.Representation.Array as R
import qualified Data.Array.Accelerate.Representation.Shape as R


-- $errors
--
-- The 'aerror', 'aerrorArray', and 'aerrorExp' functions abort the execution
-- of the array program and print errors to an output stream. They are intended
-- for stopping the program when the program is in some invalid state, which
-- was expected to be unreachable.
--
-- Besides printing a given error message, it can also print the contents of an
-- array (with 'aerrorArray') or print some scalar value ('aerrorExp').
--

-- | Stops execution of the array computation and outputs the error message to
-- the console.
--
aerror :: forall a. Arrays a => String -> Acc a
aerror message
= Acc
$ SmartAcc
$ Aerror (S.arraysR @a)
(Message (\_ -> "") (Just [|| \_ -> "" ||]) message)
(SmartAcc Anil :: SmartAcc ())

-- | Outputs the trace message and the array(s) from the second argument to
-- the console, before the 'Acc' computation proceeds with the result of
-- the third argument.
--
aerrorArray :: forall a b. (Arrays a, Arrays b, Show a) => String -> Acc a -> Acc b
aerrorArray message (Acc inspect)
= Acc
$ SmartAcc
$ Aerror (S.arraysR @b)
(Message (show . toArr @a)
(Just [|| show . toArr @a ||]) message) inspect

-- | Outputs the trace message and a scalar value to the console, before
-- the 'Acc' computation proceeds with the result of the third argument.
--
aerrorExp :: forall e a. (Elt e, Show e, Arrays a) => String -> Exp e -> Acc a
aerrorExp message value =
let Acc inspect = unit value
in Acc
$ SmartAcc
$ Aerror (S.arraysR @a)
(Message (\a -> show (toElt @e (R.indexArray (R.ArrayR R.dim0 (eltR @e)) a ())))
(Just [|| \a -> show (toElt @e (R.indexArray (R.ArrayR R.dim0 (eltR @e)) a ())) ||]) message) inspect

9 changes: 8 additions & 1 deletion src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ module Data.Array.Accelerate.Interpreter (
run, run1, runN,

-- Internal (hidden)
evalPrim, evalPrimConst, evalCoerceScalar, atraceOp,
evalPrim, evalPrimConst, evalCoerceScalar, atraceOp, aerrorOp,

) where

Expand Down Expand Up @@ -210,6 +210,7 @@ evalOpenAcc (AST.Manifest pacc) aenv =
(TupRpair r1 r2, (a1, a2))
Anil -> (TupRunit, ())
Atrace msg as bs -> unsafePerformIO $ manifest bs <$ atraceOp msg (snd $ manifest as)
Aerror _ msg as -> aerrorOp msg $ snd $ manifest as
Apply repr afun acc -> (repr, evalOpenAfun afun aenv $ snd $ manifest acc)
Aforeign repr _ afun acc -> (repr, evalOpenAfun afun Empty $ snd $ manifest acc)
Acond p acc1 acc2
Expand Down Expand Up @@ -874,6 +875,12 @@ atraceOp (Message show _ msg) as =
then traceIO msg
else traceIO $ printf "%s: %s" msg str

aerrorOp :: Message as -> as -> bs
aerrorOp (Message show _ msg) as =
let str = show as
in if null str
then error msg
else error $ printf "%s: %s" msg str

-- Scalar expression evaluation
-- ----------------------------
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Pretty/Graphviz.hs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) =

Anil -> "()" .$ []
Atrace (Message _ _ msg) as bs -> "atrace" .$ [ return $ PDoc (fromString msg) [], ppA as, ppA bs ]
Aerror _ (Message _ _ msg) as -> "aerror" .$ [ return $ PDoc (fromString msg) [], ppA as ]
Use repr arr -> "use" .$ [ return $ PDoc (prettyArray repr arr) [] ]
Unit _ e -> "unit" .$ [ ppE e ]
Generate _ sh f -> "generate" .$ [ ppE sh, ppF f ]
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Pretty/Print.hs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ prettyPreOpenAcc config ctx prettyAcc extractAcc aenv pacc =


Atrace (Message _ _ msg) as bs -> ppN "atrace" .$ [ fromString (show msg), ppA as, ppA bs ]
Aerror _ (Message _ _ msg) as -> ppN "aerror" .$ [ fromString (show msg), ppA as ]
Aforeign _ ff _ a -> ppN "aforeign" .$ [ pretty (strForeign ff), ppA a ]
Awhile p f a -> ppN "awhile" .$ [ ppAF p, ppAF f, ppA a ]
Use repr arr -> ppN "use" .$ [ prettyArray repr arr ]
Expand Down
7 changes: 7 additions & 0 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ data PreSmartAcc acc exp as where
-> acc arrs2
-> PreSmartAcc acc exp arrs2

Aerror :: ArraysR arrs2
-> Message arrs1
-> acc arrs1
-> PreSmartAcc acc exp arrs2

Use :: ArrayR (Array sh e)
-> Array sh e
-> PreSmartAcc acc exp (Array sh e)
Expand Down Expand Up @@ -805,6 +810,7 @@ instance HasArraysR acc => HasArraysR (PreSmartAcc acc exp) where
PairIdxRight -> t2
Aprj _ _ -> error "Ejector seat? You're joking!"
Atrace _ _ a -> arraysR a
Aerror repr _ _ -> repr
Use repr _ -> TupRsingle repr
Unit tp _ -> TupRsingle $ ArrayR ShapeRz $ tp
Generate repr _ _ -> TupRsingle repr
Expand Down Expand Up @@ -1315,6 +1321,7 @@ showPreAccOp Apair{} = "Apair"
showPreAccOp Anil{} = "Anil"
showPreAccOp Aprj{} = "Aprj"
showPreAccOp Atrace{} = "Atrace"
showPreAccOp Aerror{} = "Aerror"
showPreAccOp Unit{} = "Unit"
showPreAccOp Generate{} = "Generate"
showPreAccOp Reshape{} = "Reshape"
Expand Down
3 changes: 3 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Fusion.hs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ manifest config (OpenAcc pacc) =
Apair a1 a2 -> Apair (manifest config a1) (manifest config a2)
Anil -> Anil
Atrace msg a1 a2 -> Atrace msg (manifest config a1) (manifest config a2)
Aerror repr msg a1 -> Aerror repr msg (manifest config a1)
Apply repr f a -> apply repr (cvtAF f) (manifest config a)
Aforeign repr ff f a -> Aforeign repr ff (cvtAF f) (manifest config a)

Expand Down Expand Up @@ -370,6 +371,7 @@ embedPreOpenAcc config matchAcc embedAcc elimAcc pacc
Awhile p f a -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a)
Apair a1 a2 -> done $ Apair (cvtA a1) (cvtA a2)
Atrace msg a1 a2 -> done $ Atrace msg (cvtA a1) (cvtA a2)
Aerror repr msg a1 -> done $ Aerror repr msg (cvtA a1)
Aforeign aR ff f a -> done $ Aforeign aR ff (cvtAF f) (cvtA a)
-- Collect s -> collectD s

Expand Down Expand Up @@ -1548,6 +1550,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en
Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae)
Anil -> Anil
Atrace msg a b -> Atrace msg (cvtA a) (cvtA b)
Aerror repr msg a -> Aerror repr msg (cvtA a)
Apair a1 a2 -> Apair (cvtA a1) (cvtA a2)
Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (cvtA a)
Apply repr f a -> Apply repr (cvtAF f) (cvtA a)
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Trafo/LetSplit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ convertPreOpenAcc = \case
Apair a1 a2 -> Apair (convertAcc a1) (convertAcc a2)
Anil -> Anil
Atrace msg as bs -> Atrace msg (convertAcc as) (convertAcc bs)
Aerror repr msg as -> Aerror repr msg (convertAcc as)
Apply repr f a -> Apply repr (convertAfun f) (convertAcc a)
Aforeign repr asm f a -> Aforeign repr asm (convertAfun f) (convertAcc a)
Acond e a1 a2 -> Acond e (convertAcc a1) (convertAcc a2)
Expand Down
Loading