Skip to content

Commit

Permalink
Cleanup, add many comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chameco committed Jun 29, 2023
1 parent 97a2af4 commit 5a1a954
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 98 deletions.
126 changes: 65 additions & 61 deletions src/SAWScript/Yosys/CompositionalTranslation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ module SAWScript.Yosys.CompositionalTranslation
) where

import Control.Lens.TH (makeLenses)

import Control.Lens ((^.))
import Control.Monad (forM, (>=>), void)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Exception (throw)

import Data.Bifunctor (bimap)
import qualified Data.Maybe as Maybe
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Map (Map)
Expand All @@ -45,28 +45,32 @@ type CellName = Text
type Pattern = [Bitrep]
type PatternMap m = Map Pattern ((YosysBitvecConsumer -> Pattern -> m SC.Term) -> m SC.Term)

-- | Information about the state type of a particular cell
data CellStateInfo = CellStateInfo
{ _cellStateInfoType :: SC.Term -- cell state type - either a bitvector for a $dff, or a record type
, _cellStateInfoCryptolType :: C.Type -- cryptol type for the above
, _cellStateInfoFields :: Maybe (Map Text (SC.Term, C.Type)) -- if the type is a record, the fields of the record
{ _cellStateInfoType :: SC.Term -- ^ Cell state type - either a bitvector for a $dff, or a record type
, _cellStateInfoCryptolType :: C.Type -- ^ Cryptol type for the above
, _cellStateInfoFields :: Maybe (Map Text (SC.Term, C.Type)) -- ^ If the type is a record, the fields of the record
}
makeLenses ''CellStateInfo

-- | The SAWCore representation and SAW/Cryptol type information of a hardware module
data TranslatedModule = TranslatedModule
{ _translatedModuleStateInfo :: Maybe CellStateInfo -- information about the state type for this module
, _translatedModuleTerm :: SC.Term -- the lambda term for the output record (including state) in terms of the inputs (including state)
, _translatedModuleType :: SC.Term
, _translatedModuleCryptolType :: C.Type
{ _translatedModuleStateInfo :: Maybe CellStateInfo -- ^ Information about the state type for this module
, _translatedModuleTerm :: SC.Term -- ^ The lambda term for the output record (including state) in terms of the inputs (including state)
, _translatedModuleType :: SC.Term -- ^ The SAWCore type of that term
, _translatedModuleCryptolType :: C.Type -- ^ The Cryptol type of that term
}
makeLenses ''TranslatedModule

-- | Information needed when translating a module
data TranslationContext m = TranslationContext
{ _translationContextModules :: Map ModuleName TranslatedModule
, _translationContextStateTypes :: Map CellName CellStateInfo -- state type for every stateful cell in this module (including sequential submodules)
, _translationContextPatternMap :: PatternMap m -- for each pattern, a term representing that pattern (parameterized by a function to get a term representing any other pattern)
{ _translationContextModules :: Map ModuleName TranslatedModule -- ^ Context of previously translated modules
, _translationContextStateTypes :: Map CellName CellStateInfo -- ^ State information for every stateful cell in this module (including sequential submodules)
, _translationContextPatternMap :: PatternMap m -- ^ For each pattern, a term representing that pattern (parameterized by a function to get a term representing any other pattern)
}
makeLenses ''TranslationContext

-- | Given a module and the context of previously-translated modules, construct a mapping from cell names to state information
buildTranslationContextStateTypes ::
MonadIO m =>
SC.SharedContext ->
Expand All @@ -85,13 +89,14 @@ buildTranslationContextStateTypes sc mods m = do
pure $ Just CellStateInfo{..}
_ -> pure Nothing

-- | Fetch the actual state term for a cell name, given the term for the __state__ input and information about what stateful cells exist
lookupStateFor ::
forall m.
MonadIO m =>
SC.SharedContext ->
Map CellName CellStateInfo -> -- state type info for each cell
SC.Term -> -- record term mapping (zenc-ed) cell names to cell states
CellName -> -- cell state to lookup
Map CellName CellStateInfo {- ^ State type info for each cell -} ->
SC.Term {- ^ Record term mapping (zenc-ed) cell names to cell states -} ->
CellName {- ^ Cell state to lookup -} ->
m SC.Term
lookupStateFor sc states inpst cnm = do
let fieldnm = cellIdentifier cnm
Expand All @@ -101,67 +106,37 @@ lookupStateFor sc states inpst cnm = do
insertStateField ::
MonadIO m =>
SC.SharedContext ->
Map Text (SC.Term, C.Type) {- ^ The field types of \"__states__\" -} ->
Map Text (SC.Term, C.Type) {- ^ The field types of __states__ -} ->
Map Text (SC.Term, C.Type) {- ^ The mapping to update -} ->
m (Map Text (SC.Term, C.Type))
insertStateField sc stateFields fields = do
stateRecordType <- fieldsToType sc stateFields
stateRecordCryptolType <- fieldsToCryptolType stateFields
pure $ Map.insert "__state__" (stateRecordType, stateRecordCryptolType) fields

moduleInputPorts :: Module -> Map Text [Bitrep]
moduleInputPorts m =
Map.fromList
. Maybe.mapMaybe
( \(nm, ip) ->
if ip ^. portDirection == DirectionInput || ip ^. portDirection == DirectionInout
then Just (nm, ip ^. portBits)
else Nothing
)
. Map.assocs
$ m ^. modulePorts

moduleOutputPorts :: Module -> Map Text [Bitrep]
moduleOutputPorts m =
Map.fromList
. Maybe.mapMaybe
( \(nm, ip) ->
if ip ^. portDirection == DirectionOutput || ip ^. portDirection == DirectionInout
then Just (nm, ip ^. portBits)
else Nothing
)
. Map.assocs
$ m ^. modulePorts

cellInputConnections :: Cell [b] -> Map Text [b]
cellInputConnections c = Map.intersection (c ^. cellConnections) inp
where
inp = Map.filter (\d -> d == DirectionInput || d == DirectionInout) $ c ^. cellPortDirections

cellOutputConnections :: Ord b => Cell [b] -> Map Text [b]
cellOutputConnections c = Map.intersection (c ^. cellConnections) out
where
out = Map.filter (\d -> d == DirectionOutput || d == DirectionInout) $ c ^. cellPortDirections

-- | Construct a mapping from patterns to functions that construct terms for those patterns, given functions that construct terms for other patterns
-- We later "tie the knot" on this mapping given a few known patterns (e.g. module inputs and constants) to obtain actual terms for each pattern.
buildPatternMap ::
forall m.
MonadIO m =>
SC.SharedContext ->
Map ModuleName TranslatedModule -> -- all previously-translated modules
Map CellName CellStateInfo -> -- state type info for each cell
SC.Term -> -- record term mapping inputs to terms (including a field __state__, a record mapping (zenc-ed) cell names to cell states)
Module -> -- the module being translated
Map ModuleName TranslatedModule {- ^ All previously-translated modules -} ->
Map CellName CellStateInfo {- ^ State type info for each cell -} ->
SC.Term {- ^ Record term mapping inputs to terms (including a field __state__, a record mapping (zenc-ed) cell names to cell states) -} ->
Module {- ^ The module being translated -} ->
m (PatternMap m)
buildPatternMap sc mods states inp m = do
let inputPorts = moduleInputPorts m
let inputFields = if Map.null states then void inputPorts else Map.insert "__state__" () $ void inputPorts

-- obtain a term for each input port by looking up their names in the input record
inpTerms <- forM (Map.assocs inputPorts) $ \(nm, pat) -> do
t <- liftIO $ cryptolRecordSelect sc inputFields inp nm
fmap (const . pure) <$> deriveTermsByIndices sc pat t

-- grab the __state__ field from the module inputs
-- grab the __state__ field from the input record
minpst <- if Map.null states then pure Nothing else Just <$> cryptolRecordSelect sc inputFields inp "__state__"

-- for each cell, construct a term for each output pattern, parameterized by a lookup function for other patterns
ms <- forM (Map.toList $ m ^. moduleCells) $ \(cnm, c) -> do
let inpPatterns = case c ^. cellType of
Expand Down Expand Up @@ -199,6 +174,7 @@ buildPatternMap sc mods states inp m = do
[ ("Q", cst)
]
_ -> pure $ primCellToMap sc c

let
-- given a pattern lookup function build a map from output patterns to terms
f :: (YosysBitvecConsumer -> Pattern -> m SC.Term) -> m (Map Pattern SC.Term)
Expand All @@ -216,6 +192,7 @@ buildPatternMap sc mods states inp m = do
case Map.lookup pat pats of
Nothing -> panic "buildPatternMap" ["Missing expected output pattern for cell"]
Just t -> pure t

-- all of the pattern term functions for all of the cells in the module
zeroTerm <- liftIO $ SC.scBvConst sc 1 0
oneTerm <- liftIO $ SC.scBvConst sc 1 1
Expand All @@ -236,17 +213,23 @@ buildPatternMap sc mods states inp m = do
]
]

-- | Given a translation context (consisting of the previously translated modules, state information, and pattern map),
-- lookup the term for a given pattern in the pattern map.
translatePattern ::
MonadIO m =>
SC.SharedContext ->
TranslationContext m ->
YosysBitvecConsumer ->
Pattern ->
YosysBitvecConsumer {- ^ Source of this lookup (for error messages) -} ->
Pattern {- ^ Pattern to look up -} ->
m SC.Term
translatePattern sc ctx c p = do
let pmap = ctx ^. translationContextPatternMap
case Map.lookup p pmap of
-- if we find the pattern directly, use it (recursively calling translatePattern if other lookups are necessary)
Just f -> f $ translatePattern sc ctx
-- otherwise, we look up each bit individually and concatenate to construct the term.
-- this is not an optimal scheme (e.g. you can imagine patterns [1, 2] and [3, 4] being present and looking up [1, 2, 3, 4])
-- but it works well enough for now, and I suspect the resulting term size is easy to rewrite away in most cases
Nothing -> do
one <- liftIO $ SC.scNat sc 1
boolty <- liftIO $ SC.scBoolType sc
Expand All @@ -259,17 +242,18 @@ translatePattern sc ctx c p = do
vecBits <- liftIO $ SC.scVector sc onety bits
liftIO $ SC.scJoin sc many one boolty vecBits

-- ^ Given previously translated modules, translate a module.
-- (This is the exported interface to the functionality implemented here.)
translateModule ::
MonadIO m =>
SC.SharedContext ->
Map ModuleName TranslatedModule ->
Module ->
Map ModuleName TranslatedModule {- ^ Context of previously-translated modules -} ->
Module {- ^ Yosys module to translate -} ->
m TranslatedModule
translateModule sc mods m = do
-- gather information about the stateful cells of the module
states <- buildTranslationContextStateTypes sc mods m
let stateFields = Map.fromList $ bimap cellIdentifier (\cs -> (cs ^. cellStateInfoType, cs ^. cellStateInfoCryptolType)) <$> Map.toList states

-- description of the state fields of the module
_translatedModuleStateInfo <- if Map.null states
then pure Nothing
else do
Expand All @@ -281,6 +265,7 @@ translateModule sc mods m = do
, _cellStateInfoFields = Just stateFields
}

-- construct the module function's domain type (a record of all inputs, and optionally state)
let inputPorts = moduleInputPorts m
inputFields <- forM inputPorts $ \inp -> do
ty <- liftIO . SC.scBitvector sc . fromIntegral $ length inp
Expand All @@ -291,17 +276,22 @@ translateModule sc mods m = do
else insertStateField sc stateFields inputFields
domainRecordType <- fieldsToType sc domainFields
domainRecordCryptolType <- fieldsToCryptolType domainFields

-- construct a fresh variable of that type (this will become the parameter to the module function)
domainRecordEC <- liftIO $ SC.scFreshEC sc "input" domainRecordType
domainRecord <- liftIO $ SC.scExtCns sc domainRecordEC

minpst <- if Map.null states then pure Nothing else Just <$> cryptolRecordSelect sc domainFields domainRecord "__state__"
-- construct a pattern map from that domain record
pmap <- buildPatternMap sc mods states domainRecord m
let ctx = TranslationContext
{ _translationContextModules = mods
, _translationContextStateTypes = states
, _translationContextPatternMap = pmap
}

-- if this module is stateful, grab the __state__ field from the domain record
minpst <- if Map.null states then pure Nothing else Just <$> cryptolRecordSelect sc domainFields domainRecord "__state__"

-- for each stateful cell, build a term representing the new state for that cell
outstMap <- fmap Map.fromList . forM (Map.toList states) $ \(cnm, _cs) -> do
case Map.lookup cnm (m ^. moduleCells) of
Expand All @@ -314,26 +304,37 @@ translateModule sc mods m = do
| Just subm <- Map.lookup (c ^. cellType) (ctx ^. translationContextModules) -> do
-- otherwise, the cell is a stateful submodule: the new state is obtained from the submodules's update function applied to the inputs and old state
let inpPatterns = cellInputConnections c
-- lookup the term for each input to the cell
inps <- fmap Map.fromList . forM (Map.toList inpPatterns) $ \(inm, pat) ->
(inm,) <$> translatePattern sc ctx (YosysBitvecConsumerCell cnm inm) pat
let outPatterns = cellOutputConnections c
-- build a record containing all of the cell's inputs, and (if stateful) the appropriate state field
sdomainFields <- case minpst of
Nothing -> pure inps
Just inpst -> do
subinpst <- lookupStateFor sc states inpst cnm
pure $ Map.insert "__state__" subinpst inps
let scodomainFields = if Map.null states then void outPatterns else Map.insert "__state__" () $ void outPatterns
sdomainRec <- cryptolRecord sc sdomainFields
-- apply the cell's function to the domain record
scodomainRec <- liftIO $ SC.scApply sc (subm ^. translatedModuleTerm) sdomainRec
-- grab the state field from the codomain record
(cellIdentifier cnm,) <$> cryptolRecordSelect sc scodomainFields scodomainRec "__state__"
_ -> panic "translateModule" ["Malformed stateful cell type"]

-- build a record for the new value of __state__
outst <- cryptolRecord sc outstMap

-- for each module output, collect a term for the output
let outputPorts = moduleOutputPorts m
outs <- fmap Map.fromList . forM (Map.toList outputPorts) $ \(onm, pat) ->
(onm,) <$> translatePattern sc ctx (YosysBitvecConsumerOutputPort onm) pat

-- construct the return value of the module
codomainRecord <- cryptolRecord sc $ if Map.null states then outs else Map.insert "__state__" outst outs

-- construct the module function's codomain type (a record of all outputs, and optionally state)
-- (this is the type of codomainRecord)
outputFields <- forM outputPorts $ \inp -> do
ty <- liftIO . SC.scBitvector sc . fromIntegral $ length inp
let cty = C.tWord . C.tNum $ length inp
Expand All @@ -342,8 +343,11 @@ translateModule sc mods m = do
codomainRecordType <- fieldsToType sc codomainFields
codomainRecordCryptolType <- fieldsToCryptolType codomainFields

-- abstract over the return value - this binds the free variable domainRecord with a lambda
_translatedModuleTerm <- liftIO $ SC.scAbstractExts sc [domainRecordEC] codomainRecord
-- the type of _translatedModuleTerm - a function from domainRecordType to codomainRecordType
_translatedModuleType <- liftIO $ SC.scFun sc domainRecordType codomainRecordType
-- the same type as a Cryptol type
let _translatedModuleCryptolType = C.tFun domainRecordCryptolType codomainRecordCryptolType

pure TranslatedModule{..}
41 changes: 41 additions & 0 deletions src/SAWScript/Yosys/IR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ module SAWScript.Yosys.IR where

import Control.Lens.TH (makeLenses)

import Control.Lens ((^.))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Exception (throw)

import qualified Data.Maybe as Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Text (Text)
import qualified Data.Text as Text

Expand Down Expand Up @@ -133,3 +136,41 @@ loadYosysIR :: MonadIO m => FilePath -> m YosysIR
loadYosysIR p = liftIO $ Aeson.eitherDecodeFileStrict p >>= \case
Left err -> throw . YosysError $ Text.pack err
Right ir -> pure ir

-- | Return the patterns for all of the input ports of a module
moduleInputPorts :: Module -> Map Text [Bitrep]
moduleInputPorts m =
Map.fromList
. Maybe.mapMaybe
( \(nm, ip) ->
if ip ^. portDirection == DirectionInput || ip ^. portDirection == DirectionInout
then Just (nm, ip ^. portBits)
else Nothing
)
. Map.assocs
$ m ^. modulePorts

-- | Return the patterns for all of the output ports of a module
moduleOutputPorts :: Module -> Map Text [Bitrep]
moduleOutputPorts m =
Map.fromList
. Maybe.mapMaybe
( \(nm, ip) ->
if ip ^. portDirection == DirectionOutput || ip ^. portDirection == DirectionInout
then Just (nm, ip ^. portBits)
else Nothing
)
. Map.assocs
$ m ^. modulePorts

-- | Return the patterns for all of the input connections of a cell
cellInputConnections :: Cell [b] -> Map Text [b]
cellInputConnections c = Map.intersection (c ^. cellConnections) inp
where
inp = Map.filter (\d -> d == DirectionInput || d == DirectionInout) $ c ^. cellPortDirections

-- | Return the patterns for all of the output connections of a cell
cellOutputConnections :: Ord b => Cell [b] -> Map Text [b]
cellOutputConnections c = Map.intersection (c ^. cellConnections) out
where
out = Map.filter (\d -> d == DirectionOutput || d == DirectionInout) $ c ^. cellPortDirections
Loading

0 comments on commit 5a1a954

Please sign in to comment.