Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unary Stencil #512

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions accelerate.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ library
Data.Array.Accelerate.Sugar.Elt
Data.Array.Accelerate.Sugar.Foreign
Data.Array.Accelerate.Sugar.Shape
Data.Array.Accelerate.Sugar.Stencil
Data.Array.Accelerate.Sugar.Tag
Data.Array.Accelerate.Sugar.Vec
Data.Array.Accelerate.Trafo
Expand Down
15 changes: 9 additions & 6 deletions src/Data/Array/Accelerate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,14 @@ module Data.Array.Accelerate (
clamp, mirror, wrap, function,

-- *** Common stencil patterns
Stencil3, Stencil5, Stencil7, Stencil9,
Stencil3x3, Stencil5x3, Stencil3x5, Stencil5x5,
Stencil3x3x3, Stencil5x3x3, Stencil3x5x3, Stencil3x3x5, Stencil5x5x3, Stencil5x3x5,
Stencil3x5x5, Stencil5x5x5,

Unary(..),
Stencil1, Stencil3, Stencil5, Stencil7, Stencil9,
Stencil1x1, Stencil1x3, Stencil1x5,
Stencil3x1, Stencil3x3, Stencil3x5,
Stencil5x1, Stencil5x3, Stencil5x5,
Stencil3x3x3, Stencil3x5x3, Stencil3x3x5, Stencil3x5x5,
Stencil5x3x3, Stencil5x5x3, Stencil5x3x5, Stencil5x5x5,

-- -- ** Sequence operations
-- collect,

Expand Down Expand Up @@ -344,6 +347,7 @@ module Data.Array.Accelerate (
-- $pattern_synonyms
--
pattern Pattern,
pattern T1,
pattern T2, pattern T3, pattern T4, pattern T5, pattern T6,
pattern T7, pattern T8, pattern T9, pattern T10, pattern T11,
pattern T12, pattern T13, pattern T14, pattern T15, pattern T16,
Expand Down Expand Up @@ -431,7 +435,6 @@ module Data.Array.Accelerate (
CChar, CSChar, CUChar,

) where

import Data.Array.Accelerate.Classes.Bounded
import Data.Array.Accelerate.Classes.Enum
import Data.Array.Accelerate.Classes.Eq
Expand Down
15 changes: 15 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,14 @@ stencilAccess stencil = goR (stencilShapeR stencil) stencil
-- dimension is Z.
--
goR :: ShapeR sh -> StencilR sh e stencil -> (sh -> e) -> sh -> stencil
goR _ (StencilRunit1 _) rf ix =
let
(z, i) = ix
rf' d = rf (z, i+d)
in
( ()
, rf' 0 )

goR _ (StencilRunit3 _) rf ix =
let
(z, i) = ix
Expand Down Expand Up @@ -732,6 +740,13 @@ stencilAccess stencil = goR (stencilShapeR stencil) stencil
-- when we recurse on the stencil structure we must manipulate the
-- _left-most_ index component.
--
goR (ShapeRsnoc shr) (StencilRtup1 s1) rf ix =
let (i, ix') = uncons shr ix
rf' d ds = rf (cons shr (i+d) ds)
in
( ()
, goR shr s1 (rf' 0) ix')

goR (ShapeRsnoc shr) (StencilRtup3 s1 s2 s3) rf ix =
let (i, ix') = uncons shr ix
rf' d ds = rf (cons shr (i+d) ds)
Expand Down
16 changes: 10 additions & 6 deletions src/Data/Array/Accelerate/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ module Data.Array.Accelerate.Language (
-- * Conversions
ord, chr, boolToInt, bitcast,

) where
Stencil1,Stencil1x3,Stencil1x5,Stencil1x1,Stencil3x1,Stencil5x1) where
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A newline before the ) and an indent before the names would be nice here


import Data.Array.Accelerate.AST ( PrimFun(..) )
import Data.Array.Accelerate.Pattern
Expand All @@ -117,8 +117,8 @@ import Data.Array.Accelerate.Classes.Fractional
import Data.Array.Accelerate.Classes.Integral
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.Ord

import Prelude ( ($), (.), Maybe(..), Char )
import Data.Array.Accelerate.Sugar.Stencil
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The imports were structured before -- whether that's useful is debatable. But I think import Prelude should be separate because it should catch the eye.



-- $setup
Expand Down Expand Up @@ -851,14 +851,20 @@ backpermute = Acc $$$ applyAcc (Backpermute $ shapeR @sh')
--

-- DIM1 stencil type
type Stencil1 a = Unary (Exp a)
type Stencil3 a = (Exp a, Exp a, Exp a)
type Stencil5 a = (Exp a, Exp a, Exp a, Exp a, Exp a)
type Stencil7 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a)
type Stencil9 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a)

-- DIM2 stencil type
type Stencil1x1 a = Unary (Stencil1 a)
type Stencil3x1 a = Unary (Stencil3 a)
type Stencil5x1 a = Unary (Stencil5 a)
type Stencil1x3 a = (Stencil1 a, Stencil1 a, Stencil1 a)
type Stencil3x3 a = (Stencil3 a, Stencil3 a, Stencil3 a)
type Stencil5x3 a = (Stencil5 a, Stencil5 a, Stencil5 a)
type Stencil1x5 a = (Stencil1 a, Stencil1 a, Stencil1 a, Stencil1 a, Stencil1 a)
type Stencil3x5 a = (Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a)
type Stencil5x5 a = (Stencil5 a, Stencil5 a, Stencil5 a, Stencil5 a, Stencil5 a)

Expand Down Expand Up @@ -908,15 +914,13 @@ type Stencil5x5x5 a = (Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, S
-- <https://en.wikipedia.org/wiki/Gaussian_blur Gaussian blur> as a separable
-- 2-pass operation.
--
-- > type Stencil5x1 a = (Stencil3 a, Stencil5 a, Stencil3 a)
-- > type Stencil1x5 a = (Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a)
-- >
-- > convolve5x1 :: Num a => [Exp a] -> Stencil5x1 a -> Exp a
-- > convolve5x1 kernel (_, (a,b,c,d,e), _)
-- > convolve5x1 kernel (a,b,c,d,e)
-- > = Prelude.sum $ Prelude.zipWith (*) kernel [a,b,c,d,e]
-- >
-- > convolve1x5 :: Num a => [Exp a] -> Stencil1x5 a -> Exp a
-- > convolve1x5 kernel ((_,a,_), (_,b,_), (_,c,_), (_,d,_), (_,e,_))
-- > convolve1x5 kernel (a,b,c,d,e)
-- > = Prelude.sum $ Prelude.zipWith (*) kernel [a,b,c,d,e]
-- >
-- > gaussian = [0.06136,0.24477,0.38774,0.24477,0.06136]
Expand Down
18 changes: 18 additions & 0 deletions src/Data/Array/Accelerate/Pattern.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
module Data.Array.Accelerate.Pattern (

pattern Pattern,
pattern T1,
pattern T2, pattern T3, pattern T4, pattern T5, pattern T6,
pattern T7, pattern T8, pattern T9, pattern T10, pattern T11,
pattern T12, pattern T13, pattern T14, pattern T15, pattern T16,
Expand All @@ -36,6 +37,7 @@ module Data.Array.Accelerate.Pattern (

pattern V2, pattern V3, pattern V4, pattern V8, pattern V16,

Unary (..)
) where

import Data.Array.Accelerate.AST.Idx
Expand Down Expand Up @@ -293,3 +295,19 @@ runQ $ do
vs <- mapM mkV [2,3,4,8,16]
return $ concat (ts ++ is ++ vs)

newtype Unary a = Unary {runUnary :: a}
instance Elt a => Elt (Unary a) where
type EltR (Unary a) = EltR a
eltR = eltR @a
tagsR = tagsR @a
fromElt = fromElt . runUnary
toElt = Unary . toElt

pattern T1 :: Elt a => Exp a -> Exp (Unary a)
pattern T1 a = Pattern (Unary a)
{-# COMPLETE T1 #-}


instance Elt a => IsPattern Exp (Unary a) (Unary (Exp a)) where
builder (Unary (Exp a)) = Exp a
matcher (Exp t) = Unary $ Exp t
19 changes: 19 additions & 0 deletions src/Data/Array/Accelerate/Representation/Stencil.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ import Language.Haskell.TH.Extra
-- | GADT reifying the 'Stencil' class
--
data StencilR sh e pat where
StencilRunit1 :: TypeR e -> StencilR DIM1 e (Tup1 e)
StencilRunit3 :: TypeR e -> StencilR DIM1 e (Tup3 e e e)
StencilRunit5 :: TypeR e -> StencilR DIM1 e (Tup5 e e e e e)
StencilRunit7 :: TypeR e -> StencilR DIM1 e (Tup7 e e e e e e e)
StencilRunit9 :: TypeR e -> StencilR DIM1 e (Tup9 e e e e e e e e e)

StencilRtup1 :: StencilR sh e pat1
-> StencilR (sh, Int) e (Tup1 pat1)
StencilRtup3 :: StencilR sh e pat1
-> StencilR sh e pat2
-> StencilR sh e pat3
Expand Down Expand Up @@ -70,30 +73,36 @@ data StencilR sh e pat where
-> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9)

stencilEltR :: StencilR sh e pat -> TypeR e
stencilEltR (StencilRunit1 t) = t
stencilEltR (StencilRunit3 t) = t
stencilEltR (StencilRunit5 t) = t
stencilEltR (StencilRunit7 t) = t
stencilEltR (StencilRunit9 t) = t
stencilEltR (StencilRtup1 sR) = stencilEltR sR
stencilEltR (StencilRtup3 sR _ _) = stencilEltR sR
stencilEltR (StencilRtup5 sR _ _ _ _) = stencilEltR sR
stencilEltR (StencilRtup7 sR _ _ _ _ _ _) = stencilEltR sR
stencilEltR (StencilRtup9 sR _ _ _ _ _ _ _ _) = stencilEltR sR

stencilShapeR :: StencilR sh e pat -> ShapeR sh
stencilShapeR (StencilRunit1 _) = ShapeRsnoc ShapeRz
stencilShapeR (StencilRunit3 _) = ShapeRsnoc ShapeRz
stencilShapeR (StencilRunit5 _) = ShapeRsnoc ShapeRz
stencilShapeR (StencilRunit7 _) = ShapeRsnoc ShapeRz
stencilShapeR (StencilRunit9 _) = ShapeRsnoc ShapeRz
stencilShapeR (StencilRtup1 sR) = ShapeRsnoc $ stencilShapeR sR
stencilShapeR (StencilRtup3 sR _ _) = ShapeRsnoc $ stencilShapeR sR
stencilShapeR (StencilRtup5 sR _ _ _ _) = ShapeRsnoc $ stencilShapeR sR
stencilShapeR (StencilRtup7 sR _ _ _ _ _ _) = ShapeRsnoc $ stencilShapeR sR
stencilShapeR (StencilRtup9 sR _ _ _ _ _ _ _ _) = ShapeRsnoc $ stencilShapeR sR

stencilR :: StencilR sh e pat -> TypeR pat
stencilR (StencilRunit1 t) = tupR1 t
stencilR (StencilRunit3 t) = tupR3 t t t
stencilR (StencilRunit5 t) = tupR5 t t t t t
stencilR (StencilRunit7 t) = tupR7 t t t t t t t
stencilR (StencilRunit9 t) = tupR9 t t t t t t t t t
stencilR (StencilRtup1 s1) = tupR1 (stencilR s1)
stencilR (StencilRtup3 s1 s2 s3) = tupR3 (stencilR s1) (stencilR s2) (stencilR s3)
stencilR (StencilRtup5 s1 s2 s3 s4 s5) = tupR5 (stencilR s1) (stencilR s2) (stencilR s3) (stencilR s4) (stencilR s5)
stencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = tupR7 (stencilR s1) (stencilR s2) (stencilR s3) (stencilR s4) (stencilR s5) (stencilR s6) (stencilR s7)
Expand All @@ -106,11 +115,14 @@ stencilHalo :: StencilR sh e stencil -> (ShapeR sh, sh)
stencilHalo = go'
where
go' :: StencilR sh e stencil -> (ShapeR sh, sh)
go' StencilRunit1{} = (dim1, ((), 0))
go' StencilRunit3{} = (dim1, ((), 1))
go' StencilRunit5{} = (dim1, ((), 2))
go' StencilRunit7{} = (dim1, ((), 3))
go' StencilRunit9{} = (dim1, ((), 4))
--
go' (StencilRtup1 a ) = (ShapeRsnoc shR, cons shR 1 $ foldl1 (union shR) [a'])
where (shR, a') = go' a
go' (StencilRtup3 a b c ) = (ShapeRsnoc shR, cons shR 1 $ foldl1 (union shR) [a', go b, go c])
where (shR, a') = go' a
go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shR, cons shR 2 $ foldl1 (union shR) [a', go b, go c, go d, go e])
Expand All @@ -127,6 +139,9 @@ stencilHalo = go'
cons ShapeRz ix () = ((), ix)
cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz)

tupR1 :: TupR s t1 -> TupR s (Tup1 t1)
tupR1 t1 = TupRunit `TupRpair` t1

tupR3 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s (Tup3 t1 t2 t3)
tupR3 t1 t2 t3 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3

Expand All @@ -140,20 +155,24 @@ tupR9 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s t4 -> TupR s t5 -> TupR s
tupR9 t1 t2 t3 t4 t5 t6 t7 t8 t9 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 `TupRpair` t4 `TupRpair` t5 `TupRpair` t6 `TupRpair` t7 `TupRpair` t8 `TupRpair` t9

rnfStencilR :: StencilR sh e pat -> ()
rnfStencilR (StencilRunit1 t) = rnfTypeR t
rnfStencilR (StencilRunit3 t) = rnfTypeR t
rnfStencilR (StencilRunit5 t) = rnfTypeR t
rnfStencilR (StencilRunit7 t) = rnfTypeR t
rnfStencilR (StencilRunit9 t) = rnfTypeR t
rnfStencilR (StencilRtup1 s1) = rnfStencilR s1
rnfStencilR (StencilRtup3 s1 s2 s3) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3
rnfStencilR (StencilRtup5 s1 s2 s3 s4 s5) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5
rnfStencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 `seq` rnfStencilR s6 `seq` rnfStencilR s7
rnfStencilR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 `seq` rnfStencilR s6 `seq` rnfStencilR s7 `seq` rnfStencilR s8 `seq` rnfStencilR s9

liftStencilR :: StencilR sh e pat -> CodeQ (StencilR sh e pat)
liftStencilR (StencilRunit1 tp) = [|| StencilRunit1 $$(liftTypeR tp) ||]
liftStencilR (StencilRunit3 tp) = [|| StencilRunit3 $$(liftTypeR tp) ||]
liftStencilR (StencilRunit5 tp) = [|| StencilRunit5 $$(liftTypeR tp) ||]
liftStencilR (StencilRunit7 tp) = [|| StencilRunit7 $$(liftTypeR tp) ||]
liftStencilR (StencilRunit9 tp) = [|| StencilRunit9 $$(liftTypeR tp) ||]
liftStencilR (StencilRtup1 s1) = [|| StencilRtup1 $$(liftStencilR s1) ||]
liftStencilR (StencilRtup3 s1 s2 s3) = [|| StencilRtup3 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) ||]
liftStencilR (StencilRtup5 s1 s2 s3 s4 s5) = [|| StencilRtup5 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) ||]
liftStencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = [|| StencilRtup7 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) $$(liftStencilR s6) $$(liftStencilR s7) ||]
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate/Representation/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,5 @@ runQ $
in
tySynD (mkName ("Tup" ++ show n)) (map plainTV xs) rhs
in
mapM mkT [2..16]
mapM mkT [0..16]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Zero as well?


Loading