From 704cb2548a97143655e713b78c240a479576f96a Mon Sep 17 00:00:00 2001 From: David Date: Thu, 25 Nov 2021 15:16:13 +0100 Subject: [PATCH 1/3] Unary Stencil --- src/Data/Array/Accelerate.hs | 12 +++++++----- src/Data/Array/Accelerate/Interpreter.hs | 15 +++++++++++++++ src/Data/Array/Accelerate/Language.hs | 14 +++++++++----- .../Accelerate/Representation/Stencil.hs | 19 +++++++++++++++++++ .../Array/Accelerate/Representation/Type.hs | 2 +- src/Data/Array/Accelerate/Smart.hs | 5 +++++ 6 files changed, 56 insertions(+), 11 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index ff1729f27..19476fc80 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -285,11 +285,13 @@ module Data.Array.Accelerate ( clamp, mirror, wrap, function, -- *** Common stencil patterns - Stencil3, Stencil5, Stencil7, Stencil9, - Stencil3x3, Stencil5x3, Stencil3x5, Stencil5x5, - Stencil3x3x3, Stencil5x3x3, Stencil3x5x3, Stencil3x3x5, Stencil5x5x3, Stencil5x3x5, - Stencil3x5x5, Stencil5x5x5, - + Stencil1, Stencil3, Stencil5, Stencil7, Stencil9, + Stencil1x1, Stencil1x3, Stencil1x5, + Stencil3x1, Stencil3x3, Stencil3x5, + Stencil5x1, Stencil5x3, Stencil5x5, + Stencil3x3x3, Stencil3x5x3, Stencil3x3x5, Stencil3x5x5, + Stencil5x3x3, Stencil5x5x3, Stencil5x3x5, Stencil5x5x5, + -- -- ** Sequence operations -- collect, diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 5b8e6401a..d75bdc36b 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -678,6 +678,14 @@ stencilAccess stencil = goR (stencilShapeR stencil) stencil -- dimension is Z. -- goR :: ShapeR sh -> StencilR sh e stencil -> (sh -> e) -> sh -> stencil + goR _ (StencilRunit1 _) rf ix = + let + (z, i) = ix + rf' d = rf (z, i+d) + in + ( () + , rf' 0 ) + goR _ (StencilRunit3 _) rf ix = let (z, i) = ix @@ -732,6 +740,13 @@ stencilAccess stencil = goR (stencilShapeR stencil) stencil -- when we recurse on the stencil structure we must manipulate the -- _left-most_ index component. -- + goR (ShapeRsnoc shr) (StencilRtup1 s1) rf ix = + let (i, ix') = uncons shr ix + rf' d ds = rf (cons shr (i+d) ds) + in + ( () + , goR shr s1 (rf' 0) ix') + goR (ShapeRsnoc shr) (StencilRtup3 s1 s2 s3) rf ix = let (i, ix') = uncons shr ix rf' d ds = rf (cons shr (i+d) ds) diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index 727e9f7b9..eb32ffeac 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -97,7 +97,7 @@ module Data.Array.Accelerate.Language ( -- * Conversions ord, chr, boolToInt, bitcast, -) where +Stencil1,Stencil1x3,Stencil1x5,Stencil1x1,Stencil3x1,Stencil5x1) where import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Pattern @@ -851,14 +851,20 @@ backpermute = Acc $$$ applyAcc (Backpermute $ shapeR @sh') -- -- DIM1 stencil type +type Stencil1 a = (Exp a) type Stencil3 a = (Exp a, Exp a, Exp a) type Stencil5 a = (Exp a, Exp a, Exp a, Exp a, Exp a) type Stencil7 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a) type Stencil9 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a) -- DIM2 stencil type +type Stencil1x1 a = (Stencil1 a) +type Stencil3x1 a = (Stencil3 a) +type Stencil5x1 a = (Stencil5 a) +type Stencil1x3 a = (Stencil1 a, Stencil1 a, Stencil1 a) type Stencil3x3 a = (Stencil3 a, Stencil3 a, Stencil3 a) type Stencil5x3 a = (Stencil5 a, Stencil5 a, Stencil5 a) +type Stencil1x5 a = (Stencil1 a, Stencil1 a, Stencil1 a, Stencil1 a, Stencil1 a) type Stencil3x5 a = (Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a) type Stencil5x5 a = (Stencil5 a, Stencil5 a, Stencil5 a, Stencil5 a, Stencil5 a) @@ -908,15 +914,13 @@ type Stencil5x5x5 a = (Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, S -- as a separable -- 2-pass operation. -- --- > type Stencil5x1 a = (Stencil3 a, Stencil5 a, Stencil3 a) --- > type Stencil1x5 a = (Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a) -- > -- > convolve5x1 :: Num a => [Exp a] -> Stencil5x1 a -> Exp a --- > convolve5x1 kernel (_, (a,b,c,d,e), _) +-- > convolve5x1 kernel (a,b,c,d,e) -- > = Prelude.sum $ Prelude.zipWith (*) kernel [a,b,c,d,e] -- > -- > convolve1x5 :: Num a => [Exp a] -> Stencil1x5 a -> Exp a --- > convolve1x5 kernel ((_,a,_), (_,b,_), (_,c,_), (_,d,_), (_,e,_)) +-- > convolve1x5 kernel (a,b,c,d,e) -- > = Prelude.sum $ Prelude.zipWith (*) kernel [a,b,c,d,e] -- > -- > gaussian = [0.06136,0.24477,0.38774,0.24477,0.06136] diff --git a/src/Data/Array/Accelerate/Representation/Stencil.hs b/src/Data/Array/Accelerate/Representation/Stencil.hs index dd546721c..3c38652c2 100644 --- a/src/Data/Array/Accelerate/Representation/Stencil.hs +++ b/src/Data/Array/Accelerate/Representation/Stencil.hs @@ -32,11 +32,14 @@ import Language.Haskell.TH.Extra -- | GADT reifying the 'Stencil' class -- data StencilR sh e pat where + StencilRunit1 :: TypeR e -> StencilR DIM1 e (Tup1 e) StencilRunit3 :: TypeR e -> StencilR DIM1 e (Tup3 e e e) StencilRunit5 :: TypeR e -> StencilR DIM1 e (Tup5 e e e e e) StencilRunit7 :: TypeR e -> StencilR DIM1 e (Tup7 e e e e e e e) StencilRunit9 :: TypeR e -> StencilR DIM1 e (Tup9 e e e e e e e e e) + StencilRtup1 :: StencilR sh e pat1 + -> StencilR (sh, Int) e (Tup1 pat1) StencilRtup3 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 @@ -70,30 +73,36 @@ data StencilR sh e pat where -> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) stencilEltR :: StencilR sh e pat -> TypeR e +stencilEltR (StencilRunit1 t) = t stencilEltR (StencilRunit3 t) = t stencilEltR (StencilRunit5 t) = t stencilEltR (StencilRunit7 t) = t stencilEltR (StencilRunit9 t) = t +stencilEltR (StencilRtup1 sR) = stencilEltR sR stencilEltR (StencilRtup3 sR _ _) = stencilEltR sR stencilEltR (StencilRtup5 sR _ _ _ _) = stencilEltR sR stencilEltR (StencilRtup7 sR _ _ _ _ _ _) = stencilEltR sR stencilEltR (StencilRtup9 sR _ _ _ _ _ _ _ _) = stencilEltR sR stencilShapeR :: StencilR sh e pat -> ShapeR sh +stencilShapeR (StencilRunit1 _) = ShapeRsnoc ShapeRz stencilShapeR (StencilRunit3 _) = ShapeRsnoc ShapeRz stencilShapeR (StencilRunit5 _) = ShapeRsnoc ShapeRz stencilShapeR (StencilRunit7 _) = ShapeRsnoc ShapeRz stencilShapeR (StencilRunit9 _) = ShapeRsnoc ShapeRz +stencilShapeR (StencilRtup1 sR) = ShapeRsnoc $ stencilShapeR sR stencilShapeR (StencilRtup3 sR _ _) = ShapeRsnoc $ stencilShapeR sR stencilShapeR (StencilRtup5 sR _ _ _ _) = ShapeRsnoc $ stencilShapeR sR stencilShapeR (StencilRtup7 sR _ _ _ _ _ _) = ShapeRsnoc $ stencilShapeR sR stencilShapeR (StencilRtup9 sR _ _ _ _ _ _ _ _) = ShapeRsnoc $ stencilShapeR sR stencilR :: StencilR sh e pat -> TypeR pat +stencilR (StencilRunit1 t) = tupR1 t stencilR (StencilRunit3 t) = tupR3 t t t stencilR (StencilRunit5 t) = tupR5 t t t t t stencilR (StencilRunit7 t) = tupR7 t t t t t t t stencilR (StencilRunit9 t) = tupR9 t t t t t t t t t +stencilR (StencilRtup1 s1) = tupR1 (stencilR s1) stencilR (StencilRtup3 s1 s2 s3) = tupR3 (stencilR s1) (stencilR s2) (stencilR s3) stencilR (StencilRtup5 s1 s2 s3 s4 s5) = tupR5 (stencilR s1) (stencilR s2) (stencilR s3) (stencilR s4) (stencilR s5) stencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = tupR7 (stencilR s1) (stencilR s2) (stencilR s3) (stencilR s4) (stencilR s5) (stencilR s6) (stencilR s7) @@ -106,11 +115,14 @@ stencilHalo :: StencilR sh e stencil -> (ShapeR sh, sh) stencilHalo = go' where go' :: StencilR sh e stencil -> (ShapeR sh, sh) + go' StencilRunit1{} = (dim1, ((), 0)) go' StencilRunit3{} = (dim1, ((), 1)) go' StencilRunit5{} = (dim1, ((), 2)) go' StencilRunit7{} = (dim1, ((), 3)) go' StencilRunit9{} = (dim1, ((), 4)) -- + go' (StencilRtup1 a ) = (ShapeRsnoc shR, cons shR 1 $ foldl1 (union shR) [a']) + where (shR, a') = go' a go' (StencilRtup3 a b c ) = (ShapeRsnoc shR, cons shR 1 $ foldl1 (union shR) [a', go b, go c]) where (shR, a') = go' a go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shR, cons shR 2 $ foldl1 (union shR) [a', go b, go c, go d, go e]) @@ -127,6 +139,9 @@ stencilHalo = go' cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) +tupR1 :: TupR s t1 -> TupR s (Tup1 t1) +tupR1 t1 = TupRunit `TupRpair` t1 + tupR3 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s (Tup3 t1 t2 t3) tupR3 t1 t2 t3 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 @@ -140,20 +155,24 @@ tupR9 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s t4 -> TupR s t5 -> TupR s tupR9 t1 t2 t3 t4 t5 t6 t7 t8 t9 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 `TupRpair` t4 `TupRpair` t5 `TupRpair` t6 `TupRpair` t7 `TupRpair` t8 `TupRpair` t9 rnfStencilR :: StencilR sh e pat -> () +rnfStencilR (StencilRunit1 t) = rnfTypeR t rnfStencilR (StencilRunit3 t) = rnfTypeR t rnfStencilR (StencilRunit5 t) = rnfTypeR t rnfStencilR (StencilRunit7 t) = rnfTypeR t rnfStencilR (StencilRunit9 t) = rnfTypeR t +rnfStencilR (StencilRtup1 s1) = rnfStencilR s1 rnfStencilR (StencilRtup3 s1 s2 s3) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 rnfStencilR (StencilRtup5 s1 s2 s3 s4 s5) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 rnfStencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 `seq` rnfStencilR s6 `seq` rnfStencilR s7 rnfStencilR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 `seq` rnfStencilR s6 `seq` rnfStencilR s7 `seq` rnfStencilR s8 `seq` rnfStencilR s9 liftStencilR :: StencilR sh e pat -> CodeQ (StencilR sh e pat) +liftStencilR (StencilRunit1 tp) = [|| StencilRunit1 $$(liftTypeR tp) ||] liftStencilR (StencilRunit3 tp) = [|| StencilRunit3 $$(liftTypeR tp) ||] liftStencilR (StencilRunit5 tp) = [|| StencilRunit5 $$(liftTypeR tp) ||] liftStencilR (StencilRunit7 tp) = [|| StencilRunit7 $$(liftTypeR tp) ||] liftStencilR (StencilRunit9 tp) = [|| StencilRunit9 $$(liftTypeR tp) ||] +liftStencilR (StencilRtup1 s1) = [|| StencilRtup1 $$(liftStencilR s1) ||] liftStencilR (StencilRtup3 s1 s2 s3) = [|| StencilRtup3 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) ||] liftStencilR (StencilRtup5 s1 s2 s3 s4 s5) = [|| StencilRtup5 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) ||] liftStencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = [|| StencilRtup7 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) $$(liftStencilR s6) $$(liftStencilR s7) ||] diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 477f09a00..c90e7ef48 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -123,5 +123,5 @@ runQ $ in tySynD (mkName ("Tup" ++ show n)) (map plainTV xs) rhs in - mapM mkT [2..16] + mapM mkT [0..16] diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..9418a63fc 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -625,6 +625,11 @@ class Stencil sh e stencil where stencilPrj :: SmartExp (StencilR sh stencil) -> stencil -- DIM1 +instance Elt e => Stencil Sugar.DIM1 e (Exp e) where + type StencilR Sugar.DIM1 (Exp e) = ((), EltR e) + stencilR = StencilRunit1 @(EltR e) $ eltR @e + stencilPrj s = Exp $ prj0 s + instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) = EltR (e, e, e) From 6658b101a427d6a527d101e0fd21cb4d01c034ad Mon Sep 17 00:00:00 2001 From: David Date: Thu, 25 Nov 2021 22:31:00 +0100 Subject: [PATCH 2/3] Use Solo --- src/Data/Array/Accelerate.hs | 3 ++- src/Data/Array/Accelerate/Language.hs | 10 +++++----- src/Data/Array/Accelerate/Smart.hs | 12 +++++++++--- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 19476fc80..d079d3bbf 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -285,6 +285,7 @@ module Data.Array.Accelerate ( clamp, mirror, wrap, function, -- *** Common stencil patterns + Solo(..), Stencil1, Stencil3, Stencil5, Stencil7, Stencil9, Stencil1x1, Stencil1x3, Stencil1x5, Stencil3x1, Stencil3x3, Stencil3x5, @@ -433,7 +434,7 @@ module Data.Array.Accelerate ( CChar, CSChar, CUChar, ) where - +import GHC.Tuple ( Solo(..) ) import Data.Array.Accelerate.Classes.Bounded import Data.Array.Accelerate.Classes.Enum import Data.Array.Accelerate.Classes.Eq diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index eb32ffeac..d850ddfd9 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -117,7 +117,7 @@ import Data.Array.Accelerate.Classes.Fractional import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord - +import GHC.Tuple import Prelude ( ($), (.), Maybe(..), Char ) @@ -851,16 +851,16 @@ backpermute = Acc $$$ applyAcc (Backpermute $ shapeR @sh') -- -- DIM1 stencil type -type Stencil1 a = (Exp a) +type Stencil1 a = Solo (Exp a) type Stencil3 a = (Exp a, Exp a, Exp a) type Stencil5 a = (Exp a, Exp a, Exp a, Exp a, Exp a) type Stencil7 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a) type Stencil9 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a) -- DIM2 stencil type -type Stencil1x1 a = (Stencil1 a) -type Stencil3x1 a = (Stencil3 a) -type Stencil5x1 a = (Stencil5 a) +type Stencil1x1 a = Solo (Stencil1 a) +type Stencil3x1 a = Solo (Stencil3 a) +type Stencil5x1 a = Solo (Stencil5 a) type Stencil1x3 a = (Stencil1 a, Stencil1 a, Stencil1 a) type Stencil3x3 a = (Stencil3 a, Stencil3 a, Stencil3 a) type Stencil5x3 a = (Stencil5 a, Stencil5 a, Stencil5 a) diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 9418a63fc..86e9eac64 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -113,6 +113,7 @@ import Data.Text.Lazy.Builder import Formatting import GHC.TypeLits +import GHC.Tuple -- Array computations @@ -625,10 +626,10 @@ class Stencil sh e stencil where stencilPrj :: SmartExp (StencilR sh stencil) -> stencil -- DIM1 -instance Elt e => Stencil Sugar.DIM1 e (Exp e) where - type StencilR Sugar.DIM1 (Exp e) = ((), EltR e) +instance Elt e => Stencil Sugar.DIM1 e (Solo (Exp e)) where + type StencilR Sugar.DIM1 (Solo (Exp e)) = ((), EltR e) stencilR = StencilRunit1 @(EltR e) $ eltR @e - stencilPrj s = Exp $ prj0 s + stencilPrj s = Solo (Exp $ prj0 s) instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) @@ -676,6 +677,11 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e Exp $ prj0 s) -- DIM(n+1) +instance Stencil (sh:.Int) a row => Stencil (sh:.Int:.Int) a (Solo row) where + type StencilR (sh:.Int:.Int) (Solo row) = Tup1 (StencilR (sh:.Int) row) + stencilR = StencilRtup1 (stencilR @(sh:.Int) @a @row) + stencilPrj s = Solo (stencilPrj @(sh:.Int) @a $ prj0 s) + instance (Stencil (sh:.Int) a row2, Stencil (sh:.Int) a row1, Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row2, row1, row0) where From a3f2bc2f25ee87551cf3081468615b7003b7ba8c Mon Sep 17 00:00:00 2001 From: David Date: Thu, 9 Dec 2021 17:12:52 +0100 Subject: [PATCH 3/3] Self-roll Unary (because GHC.Tuple is unstable), add tests for unary stencil dimensions, use Sugar.Stencil instead of Smart.hs.Stencil --- accelerate.cabal | 1 + src/Data/Array/Accelerate.hs | 4 +- src/Data/Array/Accelerate/Language.hs | 10 +- src/Data/Array/Accelerate/Pattern.hs | 18 ++ src/Data/Array/Accelerate/Smart.hs | 190 ------------- src/Data/Array/Accelerate/Sugar/Stencil.hs | 265 ++++++++++++------ .../Accelerate/Test/NoFib/Prelude/Stencil.hs | 159 ++++++++++- src/Data/Array/Accelerate/Trafo/Sharing.hs | 2 +- 8 files changed, 367 insertions(+), 282 deletions(-) diff --git a/accelerate.cabal b/accelerate.cabal index 445d76a1b..124685d67 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -408,6 +408,7 @@ library Data.Array.Accelerate.Sugar.Elt Data.Array.Accelerate.Sugar.Foreign Data.Array.Accelerate.Sugar.Shape + Data.Array.Accelerate.Sugar.Stencil Data.Array.Accelerate.Sugar.Tag Data.Array.Accelerate.Sugar.Vec Data.Array.Accelerate.Trafo diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index d079d3bbf..a064a6eca 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -285,7 +285,7 @@ module Data.Array.Accelerate ( clamp, mirror, wrap, function, -- *** Common stencil patterns - Solo(..), + Unary(..), Stencil1, Stencil3, Stencil5, Stencil7, Stencil9, Stencil1x1, Stencil1x3, Stencil1x5, Stencil3x1, Stencil3x3, Stencil3x5, @@ -347,6 +347,7 @@ module Data.Array.Accelerate ( -- $pattern_synonyms -- pattern Pattern, + pattern T1, pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, @@ -434,7 +435,6 @@ module Data.Array.Accelerate ( CChar, CSChar, CUChar, ) where -import GHC.Tuple ( Solo(..) ) import Data.Array.Accelerate.Classes.Bounded import Data.Array.Accelerate.Classes.Enum import Data.Array.Accelerate.Classes.Eq diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index d850ddfd9..e1547a2dc 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -117,8 +117,8 @@ import Data.Array.Accelerate.Classes.Fractional import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord -import GHC.Tuple import Prelude ( ($), (.), Maybe(..), Char ) +import Data.Array.Accelerate.Sugar.Stencil -- $setup @@ -851,16 +851,16 @@ backpermute = Acc $$$ applyAcc (Backpermute $ shapeR @sh') -- -- DIM1 stencil type -type Stencil1 a = Solo (Exp a) +type Stencil1 a = Unary (Exp a) type Stencil3 a = (Exp a, Exp a, Exp a) type Stencil5 a = (Exp a, Exp a, Exp a, Exp a, Exp a) type Stencil7 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a) type Stencil9 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a) -- DIM2 stencil type -type Stencil1x1 a = Solo (Stencil1 a) -type Stencil3x1 a = Solo (Stencil3 a) -type Stencil5x1 a = Solo (Stencil5 a) +type Stencil1x1 a = Unary (Stencil1 a) +type Stencil3x1 a = Unary (Stencil3 a) +type Stencil5x1 a = Unary (Stencil5 a) type Stencil1x3 a = (Stencil1 a, Stencil1 a, Stencil1 a) type Stencil3x3 a = (Stencil3 a, Stencil3 a, Stencil3 a) type Stencil5x3 a = (Stencil5 a, Stencil5 a, Stencil5 a) diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index e212c0869..3bf2a3ca7 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -26,6 +26,7 @@ module Data.Array.Accelerate.Pattern ( pattern Pattern, + pattern T1, pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, @@ -36,6 +37,7 @@ module Data.Array.Accelerate.Pattern ( pattern V2, pattern V3, pattern V4, pattern V8, pattern V16, + Unary (..) ) where import Data.Array.Accelerate.AST.Idx @@ -293,3 +295,19 @@ runQ $ do vs <- mapM mkV [2,3,4,8,16] return $ concat (ts ++ is ++ vs) +newtype Unary a = Unary {runUnary :: a} +instance Elt a => Elt (Unary a) where + type EltR (Unary a) = EltR a + eltR = eltR @a + tagsR = tagsR @a + fromElt = fromElt . runUnary + toElt = Unary . toElt + +pattern T1 :: Elt a => Exp a -> Exp (Unary a) +pattern T1 a = Pattern (Unary a) +{-# COMPLETE T1 #-} + + +instance Elt a => IsPattern Exp (Unary a) (Unary (Exp a)) where + builder (Unary (Exp a)) = Exp a + matcher (Exp t) = Unary $ Exp t diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 86e9eac64..a3806c920 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -36,7 +36,6 @@ module Data.Array.Accelerate.Smart ( -- ** Scalar expressions Exp(..), SmartExp(..), PreSmartExp(..), - Stencil(..), Boundary(..), PreBoundary(..), PrimBool, PrimMaybe, @@ -100,7 +99,6 @@ import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) ) import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Representation.Stencil as R import qualified Data.Array.Accelerate.Sugar.Array as Sugar -import qualified Data.Array.Accelerate.Sugar.Shape as Sugar import Data.Array.Accelerate.AST ( Direction(..), Message(..) , PrimBool, PrimMaybe @@ -108,12 +106,10 @@ import Data.Array.Accelerate.AST ( Direction( , PrimConst(..), primConstType ) import Data.Primitive.Vec -import Data.Kind import Data.Text.Lazy.Builder import Formatting import GHC.TypeLits -import GHC.Tuple -- Array computations @@ -587,10 +583,6 @@ data PreSmartExp acc exp t where -> exp a -> PreSmartExp acc exp b - --- Smart constructors for stencils --- ------------------------------- - -- | Boundary condition specification for stencil operations -- data Boundary t where @@ -608,188 +600,6 @@ data PreBoundary acc exp t where Function :: (SmartExp sh -> exp e) -> PreBoundary acc exp (Array sh e) - --- Stencil reification --- ------------------- --- --- In the AST representation, we turn the stencil type from nested tuples --- of Accelerate expressions into an Accelerate expression whose type is --- a tuple nested in the same manner. This enables us to represent the --- stencil function as a unary function (which also only needs one de --- Bruijn index). The various positions in the stencil are accessed via --- tuple indices (i.e., projections). --- -class Stencil sh e stencil where - type StencilR sh stencil :: Type - - stencilR :: R.StencilR (EltR sh) (EltR e) (StencilR sh stencil) - stencilPrj :: SmartExp (StencilR sh stencil) -> stencil - --- DIM1 -instance Elt e => Stencil Sugar.DIM1 e (Solo (Exp e)) where - type StencilR Sugar.DIM1 (Solo (Exp e)) = ((), EltR e) - stencilR = StencilRunit1 @(EltR e) $ eltR @e - stencilPrj s = Solo (Exp $ prj0 s) - -instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where - type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) - = EltR (e, e, e) - stencilR = StencilRunit3 @(EltR e) $ eltR @e - stencilPrj s = (Exp $ prj2 s, - Exp $ prj1 s, - Exp $ prj0 s) - -instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where - type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e) - stencilR = StencilRunit5 $ eltR @e - stencilPrj s = (Exp $ prj4 s, - Exp $ prj3 s, - Exp $ prj2 s, - Exp $ prj1 s, - Exp $ prj0 s) - -instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where - type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e, e, e) - stencilR = StencilRunit7 $ eltR @e - stencilPrj s = (Exp $ prj6 s, - Exp $ prj5 s, - Exp $ prj4 s, - Exp $ prj3 s, - Exp $ prj2 s, - Exp $ prj1 s, - Exp $ prj0 s) - -instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - where - type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e, e, e, e, e) - stencilR = StencilRunit9 $ eltR @e - stencilPrj s = (Exp $ prj8 s, - Exp $ prj7 s, - Exp $ prj6 s, - Exp $ prj5 s, - Exp $ prj4 s, - Exp $ prj3 s, - Exp $ prj2 s, - Exp $ prj1 s, - Exp $ prj0 s) - --- DIM(n+1) -instance Stencil (sh:.Int) a row => Stencil (sh:.Int:.Int) a (Solo row) where - type StencilR (sh:.Int:.Int) (Solo row) = Tup1 (StencilR (sh:.Int) row) - stencilR = StencilRtup1 (stencilR @(sh:.Int) @a @row) - stencilPrj s = Solo (stencilPrj @(sh:.Int) @a $ prj0 s) - -instance (Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row2, row1, row0) - = Tup3 (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) - stencilR = StencilRtup3 (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) - stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj2 s, - stencilPrj @(sh:.Int) @a $ prj1 s, - stencilPrj @(sh:.Int) @a $ prj0 s) - -instance (Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row4, row3, row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row4, row3, row2, row1, row0) - = Tup5 (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) - (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) - stencilR = StencilRtup5 (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) - (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) - stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj4 s, - stencilPrj @(sh:.Int) @a $ prj3 s, - stencilPrj @(sh:.Int) @a $ prj2 s, - stencilPrj @(sh:.Int) @a $ prj1 s, - stencilPrj @(sh:.Int) @a $ prj0 s) - -instance (Stencil (sh:.Int) a row6, - Stencil (sh:.Int) a row5, - Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row0) - => Stencil (sh:.Int:.Int) a (row6, row5, row4, row3, row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row6, row5, row4, row3, row2, row1, row0) - = Tup7 (StencilR (sh:.Int) row6) (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) - (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) - (StencilR (sh:.Int) row0) - stencilR = StencilRtup7 (stencilR @(sh:.Int) @a @row6) - (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) - (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) - stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj6 s, - stencilPrj @(sh:.Int) @a $ prj5 s, - stencilPrj @(sh:.Int) @a $ prj4 s, - stencilPrj @(sh:.Int) @a $ prj3 s, - stencilPrj @(sh:.Int) @a $ prj2 s, - stencilPrj @(sh:.Int) @a $ prj1 s, - stencilPrj @(sh:.Int) @a $ prj0 s) - -instance (Stencil (sh:.Int) a row8, - Stencil (sh:.Int) a row7, - Stencil (sh:.Int) a row6, - Stencil (sh:.Int) a row5, - Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row0) - => Stencil (sh:.Int:.Int) a (row8, row7, row6, row5, row4, row3, row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row8, row7, row6, row5, row4, row3, row2, row1, row0) - = Tup9 (StencilR (sh:.Int) row8) (StencilR (sh:.Int) row7) (StencilR (sh:.Int) row6) - (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) - (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) - stencilR = StencilRtup9 - (stencilR @(sh:.Int) @a @row8) (stencilR @(sh:.Int) @a @row7) (stencilR @(sh:.Int) @a @row6) - (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) - (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) - stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj8 s, - stencilPrj @(sh:.Int) @a $ prj7 s, - stencilPrj @(sh:.Int) @a $ prj6 s, - stencilPrj @(sh:.Int) @a $ prj5 s, - stencilPrj @(sh:.Int) @a $ prj4 s, - stencilPrj @(sh:.Int) @a $ prj3 s, - stencilPrj @(sh:.Int) @a $ prj2 s, - stencilPrj @(sh:.Int) @a $ prj1 s, - stencilPrj @(sh:.Int) @a $ prj0 s) - -prjTail :: SmartExp (t, a) -> SmartExp t -prjTail = SmartExp . Prj PairIdxLeft - -prj0 :: SmartExp (t, a) -> SmartExp a -prj0 = SmartExp . Prj PairIdxRight - -prj1 :: SmartExp ((t, a), s0) -> SmartExp a -prj1 = prj0 . prjTail - -prj2 :: SmartExp (((t, a), s1), s0) -> SmartExp a -prj2 = prj1 . prjTail - -prj3 :: SmartExp ((((t, a), s2), s1), s0) -> SmartExp a -prj3 = prj2 . prjTail - -prj4 :: SmartExp (((((t, a), s3), s2), s1), s0) -> SmartExp a -prj4 = prj3 . prjTail - -prj5 :: SmartExp ((((((t, a), s4), s3), s2), s1), s0) -> SmartExp a -prj5 = prj4 . prjTail - -prj6 :: SmartExp (((((((t, a), s5), s4), s3), s2), s1), s0) -> SmartExp a -prj6 = prj5 . prjTail - -prj7 :: SmartExp ((((((((t, a), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a -prj7 = prj6 . prjTail - -prj8 :: SmartExp (((((((((t, a), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a -prj8 = prj7 . prjTail - - -- Extracting type information -- --------------------------- diff --git a/src/Data/Array/Accelerate/Sugar/Stencil.hs b/src/Data/Array/Accelerate/Sugar/Stencil.hs index 15039c320..0058c696e 100644 --- a/src/Data/Array/Accelerate/Sugar/Stencil.hs +++ b/src/Data/Array/Accelerate/Sugar/Stencil.hs @@ -25,94 +25,197 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Representation.Type import qualified Data.Array.Accelerate.Representation.Stencil as R +import Data.Array.Accelerate.Representation.Stencil hiding ( StencilR, stencilR ) import Data.Kind +import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Smart +import qualified Data.Array.Accelerate.Sugar.Shape as Sugar +import Data.Array.Accelerate.AST.Idx --- Reification of the stencil type from nested tuples of Accelerate --- expressions in the surface language. This enables us to represent the --- stencil function as a unary function. +-- Stencil reification +-- ------------------- +-- +-- In the AST representation, we turn the stencil type from nested tuples +-- of Accelerate expressions into an Accelerate expression whose type is +-- a tuple nested in the same manner. This enables us to represent the +-- stencil function as a unary function (which also only needs one de +-- Bruijn index). The various positions in the stencil are accessed via +-- tuple indices (i.e., projections). -- class Stencil sh e stencil where type StencilR sh stencil :: Type - stencilR :: R.StencilR (EltR sh) (EltR e) (StencilR sh stencil) - -instance Elt e => Stencil DIM1 e (exp e, exp e, exp e) where - type StencilR DIM1 (exp e, exp e, exp e) = EltR (e, e, e) - stencilR = R.StencilRunit3 $ eltR @e - -instance Elt e => Stencil DIM1 e (exp e, exp e, exp e, exp e, exp e) where - type StencilR DIM1 (exp e, exp e, exp e, exp e, exp e) = - EltR (e, e, e, e, e) - stencilR = R.StencilRunit5 $ eltR @e - -instance Elt e => Stencil DIM1 e (exp e, exp e, exp e, exp e, exp e, exp e, exp e) where - type StencilR DIM1 (exp e, exp e, exp e, exp e, exp e, exp e, exp e) = - EltR (e, e, e, e, e, e, e) - stencilR = R.StencilRunit7 $ eltR @e - -instance Elt e => Stencil DIM1 e (exp e, exp e, exp e, exp e, exp e, exp e, exp e, exp e, exp e) where - type StencilR DIM1 (exp e, exp e, exp e, exp e, exp e, exp e, exp e, exp e, exp e) = - EltR (e, e, e, e, e, e, e, e, e) - stencilR = R.StencilRunit9 $ eltR @e - -instance ( Stencil (sh:.Int) a row2 - , Stencil (sh:.Int) a row1 - , Stencil (sh:.Int) a row0 - ) - => Stencil (sh:.Int:.Int) a (row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row2, row1, row0) = - Tup3 (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) - stencilR = R.StencilRtup3 (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) - -instance ( Stencil (sh:.Int) a row4 - , Stencil (sh:.Int) a row3 - , Stencil (sh:.Int) a row2 - , Stencil (sh:.Int) a row1 - , Stencil (sh:.Int) a row0 - ) - => Stencil (sh:.Int:.Int) a (row4, row3, row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row4, row3, row2, row1, row0) = - Tup5 (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) - (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) - stencilR = R.StencilRtup5 - (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) (stencilR @(sh:.Int) @a @row2) - (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) - -instance ( Stencil (sh:.Int) a row6 - , Stencil (sh:.Int) a row5 - , Stencil (sh:.Int) a row4 - , Stencil (sh:.Int) a row3 - , Stencil (sh:.Int) a row2 - , Stencil (sh:.Int) a row1 - , Stencil (sh:.Int) a row0 - ) - => Stencil (sh:.Int:.Int) a (row6, row5, row4, row3, row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row6, row5, row4, row3, row2, row1, row0) = - Tup7 (StencilR (sh:.Int) row6) (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) - (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) - (StencilR (sh:.Int) row0) - stencilR = R.StencilRtup7 - (stencilR @(sh:.Int) @a @row6) (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) - (stencilR @(sh:.Int) @a @row3) (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) - (stencilR @(sh:.Int) @a @row0) - -instance ( Stencil (sh:.Int) a row8 - , Stencil (sh:.Int) a row7 - , Stencil (sh:.Int) a row6 - , Stencil (sh:.Int) a row5 - , Stencil (sh:.Int) a row4 - , Stencil (sh:.Int) a row3 - , Stencil (sh:.Int) a row2 - , Stencil (sh:.Int) a row1 - , Stencil (sh:.Int) a row0 - ) - => Stencil (sh:.Int:.Int) a (row8, row7, row6, row5, row4, row3, row2, row1, row0) where - type StencilR (sh:.Int:.Int) (row8, row7, row6, row5, row4, row3, row2, row1, row0) = - Tup9 (StencilR (sh:.Int) row8) (StencilR (sh:.Int) row7) (StencilR (sh:.Int) row6) - (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) - (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) - stencilR = R.StencilRtup9 + + stencilR :: R.StencilR (EltR sh) (EltR e) (StencilR sh stencil) + stencilPrj :: SmartExp (StencilR sh stencil) -> stencil + +-- DIM1 +instance Elt e => Stencil Sugar.DIM1 e (Unary (Exp e)) where + type StencilR Sugar.DIM1 (Unary (Exp e)) = ((), EltR e) + stencilR = StencilRunit1 @(EltR e) $ eltR @e + stencilPrj s = Unary (Exp $ prj0 s) + +instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where + type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) + = EltR (e, e, e) + stencilR = StencilRunit3 @(EltR e) $ eltR @e + stencilPrj s = (Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) + +instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where + type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) + = EltR (e, e, e, e, e) + stencilR = StencilRunit5 $ eltR @e + stencilPrj s = (Exp $ prj4 s, + Exp $ prj3 s, + Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) + +instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where + type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) + = EltR (e, e, e, e, e, e, e) + stencilR = StencilRunit7 $ eltR @e + stencilPrj s = (Exp $ prj6 s, + Exp $ prj5 s, + Exp $ prj4 s, + Exp $ prj3 s, + Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) + +instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) + where + type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) + = EltR (e, e, e, e, e, e, e, e, e) + stencilR = StencilRunit9 $ eltR @e + stencilPrj s = (Exp $ prj8 s, + Exp $ prj7 s, + Exp $ prj6 s, + Exp $ prj5 s, + Exp $ prj4 s, + Exp $ prj3 s, + Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) + +-- DIM(n+1) +instance Stencil (sh:.Int) a row => Stencil (sh:.Int:.Int) a (Unary row) where + type StencilR (sh:.Int:.Int) (Unary row) = Tup1 (StencilR (sh:.Int) row) + stencilR = StencilRtup1 (stencilR @(sh:.Int) @a @row) + stencilPrj s = Unary (stencilPrj @(sh:.Int) @a $ prj0 s) + +instance (Stencil (sh:.Int) a row2, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row2, row1, row0) where + type StencilR (sh:.Int:.Int) (row2, row1, row0) + = Tup3 (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) + stencilR = StencilRtup3 (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) + +instance (Stencil (sh:.Int) a row4, + Stencil (sh:.Int) a row3, + Stencil (sh:.Int) a row2, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row4, row3, row2, row1, row0) where + type StencilR (sh:.Int:.Int) (row4, row3, row2, row1, row0) + = Tup5 (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) + (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) + stencilR = StencilRtup5 (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) + (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj4 s, + stencilPrj @(sh:.Int) @a $ prj3 s, + stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) + +instance (Stencil (sh:.Int) a row6, + Stencil (sh:.Int) a row5, + Stencil (sh:.Int) a row4, + Stencil (sh:.Int) a row3, + Stencil (sh:.Int) a row2, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) + => Stencil (sh:.Int:.Int) a (row6, row5, row4, row3, row2, row1, row0) where + type StencilR (sh:.Int:.Int) (row6, row5, row4, row3, row2, row1, row0) + = Tup7 (StencilR (sh:.Int) row6) (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) + (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) + (StencilR (sh:.Int) row0) + stencilR = StencilRtup7 (stencilR @(sh:.Int) @a @row6) + (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) + (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj6 s, + stencilPrj @(sh:.Int) @a $ prj5 s, + stencilPrj @(sh:.Int) @a $ prj4 s, + stencilPrj @(sh:.Int) @a $ prj3 s, + stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) + +instance (Stencil (sh:.Int) a row8, + Stencil (sh:.Int) a row7, + Stencil (sh:.Int) a row6, + Stencil (sh:.Int) a row5, + Stencil (sh:.Int) a row4, + Stencil (sh:.Int) a row3, + Stencil (sh:.Int) a row2, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) + => Stencil (sh:.Int:.Int) a (row8, row7, row6, row5, row4, row3, row2, row1, row0) where + type StencilR (sh:.Int:.Int) (row8, row7, row6, row5, row4, row3, row2, row1, row0) + = Tup9 (StencilR (sh:.Int) row8) (StencilR (sh:.Int) row7) (StencilR (sh:.Int) row6) + (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) + (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0) + stencilR = StencilRtup9 (stencilR @(sh:.Int) @a @row8) (stencilR @(sh:.Int) @a @row7) (stencilR @(sh:.Int) @a @row6) (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj8 s, + stencilPrj @(sh:.Int) @a $ prj7 s, + stencilPrj @(sh:.Int) @a $ prj6 s, + stencilPrj @(sh:.Int) @a $ prj5 s, + stencilPrj @(sh:.Int) @a $ prj4 s, + stencilPrj @(sh:.Int) @a $ prj3 s, + stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) + +-- Smart constructors for stencils +-- ------------------------------- + + + + +prjTail :: SmartExp (t, a) -> SmartExp t +prjTail = SmartExp . Prj PairIdxLeft + +prj0 :: SmartExp (t, a) -> SmartExp a +prj0 = SmartExp . Prj PairIdxRight + +prj1 :: SmartExp ((t, a), s0) -> SmartExp a +prj1 = prj0 . prjTail + +prj2 :: SmartExp (((t, a), s1), s0) -> SmartExp a +prj2 = prj1 . prjTail + +prj3 :: SmartExp ((((t, a), s2), s1), s0) -> SmartExp a +prj3 = prj2 . prjTail + +prj4 :: SmartExp (((((t, a), s3), s2), s1), s0) -> SmartExp a +prj4 = prj3 . prjTail + +prj5 :: SmartExp ((((((t, a), s4), s3), s2), s1), s0) -> SmartExp a +prj5 = prj4 . prjTail + +prj6 :: SmartExp (((((((t, a), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj6 = prj5 . prjTail + +prj7 :: SmartExp ((((((((t, a), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj7 = prj6 . prjTail + +prj8 :: SmartExp (((((((((t, a), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj8 = prj7 . prjTail diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs index dec121c68..d5bc8f5d3 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs @@ -32,6 +32,7 @@ import Data.Array.Accelerate.Sugar.Elt as S import Data.Array.Accelerate.Sugar.Array as S import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Analysis.Match +import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Type import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config @@ -75,7 +76,8 @@ test_stencil runN = testDim1 :: TestTree testDim1 = testGroup "DIM1" - [ testProperty "stencil3" $ test_stencil3 runN e + [ testProperty "stencil1" $ test_stencil1 runN e + , testProperty "stencil3" $ test_stencil3 runN e , testProperty "stencil5" $ test_stencil5 runN e , testProperty "stencil7" $ test_stencil7 runN e , testProperty "stencil9" $ test_stencil9 runN e @@ -84,7 +86,9 @@ test_stencil runN = testDim2 :: TestTree testDim2 = testGroup "DIM2" - [ testProperty "stencil3x3" $ test_stencil3x3 runN e + [ testProperty "stencil1x3" $ test_stencil1x3 runN e + , testProperty "stencil3x1" $ test_stencil3x1 runN e + , testProperty "stencil3x3" $ test_stencil3x3 runN e , testProperty "stencil5x5" $ test_stencil5x5 runN e , testProperty "stencil7x7" $ test_stencil7x7 runN e , testProperty "stencil9x9" $ test_stencil9x9 runN e @@ -97,6 +101,25 @@ test_stencil runN = ] +test_stencil1 + :: (P.Num e, A.Num e, Similar e, Show e) + => RunN + -> Gen e + -> Property +test_stencil1 runN e = + property $ do + sh <- forAll ((Z :.) P.<$> Gen.int (Range.linear 2 256)) + xs <- forAll (array sh e) + b <- forAll (boundary e) + P3 _ a r <- forAll pattern3 + let !go = case b of + Clamp -> runN (A.stencil a A.clamp) + Wrap -> runN (A.stencil a A.wrap) + Mirror -> runN (A.stencil a A.mirror) + Constant x -> runN (A.stencil a (A.function (\_ -> constant x))) + -- + go xs ~~~ stencil3Ref r b xs + test_stencil3 :: (P.Num e, A.Num e, Similar e, Show e) => RunN @@ -174,6 +197,52 @@ test_stencil9 runN e = go xs ~~~ stencil9Ref r b xs +test_stencil1x3 + :: (P.Num e, A.Num e, Similar e, Show e) + => RunN + -> Gen e + -> Property +test_stencil1x3 runN e = + property $ do + sy <- forAll (Gen.int (Range.linear 2 96)) + sx <- forAll (Gen.int (Range.linear 2 96)) + let sh = Z :. sy :. sx + xs <- forAll (array sh e) + b <- forAll (boundary e) + P1x3 _ a r <- forAll pattern1x3 + let !go = case b of + Clamp -> runN (A.stencil a A.clamp) + Wrap -> runN (A.stencil a A.wrap) + Mirror -> runN (A.stencil a A.mirror) + Constant x -> runN (A.stencil a (A.function (\_ -> constant x))) + -- + go xs ~~~ stencil1x3Ref r b xs + + + +test_stencil3x1 + :: (P.Num e, A.Num e, Similar e, Show e) + => RunN + -> Gen e + -> Property +test_stencil3x1 runN e = + property $ do + sy <- forAll (Gen.int (Range.linear 2 96)) + sx <- forAll (Gen.int (Range.linear 2 96)) + let sh = Z :. sy :. sx + xs <- forAll (array sh e) + b <- forAll (boundary e) + P3x1 _ a r <- forAll pattern3x1 + let !go = case b of + Clamp -> runN (A.stencil a A.clamp) + Wrap -> runN (A.stencil a A.wrap) + Mirror -> runN (A.stencil a A.mirror) + Constant x -> runN (A.stencil a (A.function (\_ -> constant x))) + -- + go xs ~~~ stencil3x1Ref r b xs + + + test_stencil3x3 :: (P.Num e, A.Num e, Similar e, Show e) => RunN @@ -280,12 +349,14 @@ test_stencil3x3x3 runN e = -- go xs ~~~ stencil3x3x3Ref r b xs - +type Stencil1Ref a = Unary a type Stencil3Ref a = (a,a,a) type Stencil5Ref a = (a,a,a,a,a) type Stencil7Ref a = (a,a,a,a,a,a,a) type Stencil9Ref a = (a,a,a,a,a,a,a,a,a) +type Stencil3x1Ref a = Unary (Stencil3Ref a) +type Stencil1x3Ref a = Stencil3Ref (Unary a) type Stencil3x3Ref a = (Stencil3Ref a, Stencil3Ref a, Stencil3Ref a) type Stencil5x5Ref a = (Stencil5Ref a, Stencil5Ref a, Stencil5Ref a, Stencil5Ref a, Stencil5Ref a) type Stencil7x7Ref a = (Stencil7Ref a, Stencil7Ref a, Stencil7Ref a, Stencil7Ref a, Stencil7Ref a, Stencil7Ref a, Stencil7Ref a) @@ -313,11 +384,14 @@ boundary e = , pure Mirror ] +data Pattern1 a = P1 [Int] (Stencil1 a -> Exp a) (Stencil1Ref a -> a) data Pattern3 a = P3 [Int] (Stencil3 a -> Exp a) (Stencil3Ref a -> a) data Pattern5 a = P5 [Int] (Stencil5 a -> Exp a) (Stencil5Ref a -> a) data Pattern7 a = P7 [Int] (Stencil7 a -> Exp a) (Stencil7Ref a -> a) data Pattern9 a = P9 [Int] (Stencil9 a -> Exp a) (Stencil9Ref a -> a) +data Pattern1x3 a = P1x3 [[Int]] (Stencil1x3 a -> Exp a) (Stencil1x3Ref a -> a) +data Pattern3x1 a = P3x1 [[Int]] (Stencil3x1 a -> Exp a) (Stencil3x1Ref a -> a) data Pattern3x3 a = P3x3 [[Int]] (Stencil3x3 a -> Exp a) (Stencil3x3Ref a -> a) data Pattern5x5 a = P5x5 [[Int]] (Stencil5x5 a -> Exp a) (Stencil5x5Ref a -> a) data Pattern7x7 a = P7x7 [[Int]] (Stencil7x7 a -> Exp a) (Stencil7x7Ref a -> a) @@ -325,11 +399,14 @@ data Pattern9x9 a = P9x9 [[Int]] (Stencil9x9 a -> Exp a) (Stencil9x9Ref a -> a) data Pattern3x3x3 a = P3x3x3 [[[Int]]] (Stencil3x3x3 a -> Exp a) (Stencil3x3x3Ref a -> a) +instance Show (Pattern1 a) where show (P1 ix _ _) = show ix instance Show (Pattern3 a) where show (P3 ix _ _) = show ix instance Show (Pattern5 a) where show (P5 ix _ _) = show ix instance Show (Pattern7 a) where show (P7 ix _ _) = show ix instance Show (Pattern9 a) where show (P9 ix _ _) = show ix +instance Show (Pattern1x3 a) where show (P1x3 ix _ _) = show ix +instance Show (Pattern3x1 a) where show (P3x1 ix _ _) = show ix instance Show (Pattern3x3 a) where show (P3x3 ix _ _) = show ix instance Show (Pattern5x5 a) where show (P5x5 ix _ _) = show ix instance Show (Pattern7x7 a) where show (P7x7 ix _ _) = show ix @@ -338,6 +415,13 @@ instance Show (Pattern9x9 a) where show (P9x9 ix _ _) = show ix instance Show (Pattern3x3x3 a) where show (P3x3x3 ix _ _) = show ix +pattern1 :: (P.Num a, A.Num a) => Gen (Pattern1 a) +pattern1 = do + i <- Gen.subsequence [0] + pure $ + P1 i (\(Unary x0) -> P.sum (P.map ([x0] P.!!) i)) + (\(Unary x0) -> P.sum (P.map ([x0] P.!!) i)) + pattern3 :: (P.Num a, A.Num a) => Gen (Pattern3 a) pattern3 = do i <- Gen.subsequence [0..2] @@ -366,6 +450,24 @@ pattern9 = do P9 i (\(x0,x1,x2,x3,x4,x5,x6,x7,x8) -> P.sum (P.map ([x0,x1,x2,x3,x4,x5,x6,x7,x8] P.!!) i)) (\(x0,x1,x2,x3,x4,x5,x6,x7,x8) -> P.sum (P.map ([x0,x1,x2,x3,x4,x5,x6,x7,x8] P.!!) i)) +pattern1x3 :: (P.Num a, A.Num a) => Gen (Pattern1x3 a) +pattern1x3 = do + P1 i0 a0 r0 <- pattern1 + P1 i1 a1 r1 <- pattern1 + P1 i2 a2 r2 <- pattern1 + pure $ + P1x3 [i0,i1,i2] + (\(x0,x1,x2) -> P.sum [a0 x0, a1 x1, a2 x2]) + (\(x0,x1,x2) -> P.sum [r0 x0, r1 x1, r2 x2]) + +pattern3x1 :: (P.Num a, A.Num a) => Gen (Pattern3x1 a) +pattern3x1 = do + P3 i0 a0 r0 <- pattern3 + pure $ + P3x1 [i0] + (\(Unary x0) -> P.sum [a0 x0]) + (\(Unary x0) -> P.sum [r0 x0]) + pattern3x3 :: (P.Num a, A.Num a) => Gen (Pattern3x3 a) pattern3x3 = do P3 i0 a0 r0 <- pattern3 @@ -430,6 +532,20 @@ pattern3x3x3 = do +stencil1Ref + :: Elt a + => (Stencil1Ref a -> a) + -> SimpleBoundary a + -> Vector a + -> Vector a +stencil1Ref st bnd arr = + let sh = S.shape arr + in + fromFunction sh + (\ix@(Z:.n) -> let x = arr S.! ix + in + st (Unary x)) + stencil3Ref :: Elt a => (Stencil3Ref a -> a) @@ -515,6 +631,43 @@ stencil9Ref st bnd arr = in st (x0,x1,x2,x3,x4,x5,x6,x7,x8)) +stencil1x3Ref + :: Elt a + => (Stencil1x3Ref a -> a) + -> SimpleBoundary a + -> Matrix a + -> Matrix a +stencil1x3Ref st bnd arr = + let sh = S.shape arr + in + fromFunction sh + (\(Z:.j:.i) -> + let get it = either id (arr S.!) (bound bnd sh it) + -- + x0 = get (Z :. j-1 :. i) + x1 = get (Z :. j :. i) + x2 = get (Z :. j+1 :. i) + in + st (Unary x0, Unary x1, Unary x2)) + + +stencil3x1Ref + :: Elt a + => (Stencil3x1Ref a -> a) + -> SimpleBoundary a + -> Matrix a + -> Matrix a +stencil3x1Ref st bnd arr = + let sh = S.shape arr + in + fromFunction sh + (\(Z:.j:.i) -> + let get it = either id (arr S.!) (bound bnd sh it) + -- + x = ( get (Z :. j :. i-1), get (Z :. j :. i), get (Z :. j :. i+1) ) + in + st (Unary x)) + stencil3x3Ref :: Elt a => (Stencil3x3Ref a -> a) diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 67ead04f0..dbac77050 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -59,7 +59,7 @@ import Data.Array.Accelerate.Representation.Shape hiding ( zip import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Smart as Smart hiding ( StencilR ) +import Data.Array.Accelerate.Smart as Smart import Data.Array.Accelerate.Sugar.Array hiding ( Array, ArraysR, (!!) ) import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Trafo.Config