From 5a1a954294d1046edf72f544982ddc3bc5c59fe3 Mon Sep 17 00:00:00 2001 From: Samuel Breese Date: Thu, 29 Jun 2023 15:59:58 -0400 Subject: [PATCH] Cleanup, add many comments --- .../Yosys/CompositionalTranslation.hs | 126 +++++++++--------- src/SAWScript/Yosys/IR.hs | 41 ++++++ src/SAWScript/Yosys/Netgraph.hs | 39 +----- 3 files changed, 108 insertions(+), 98 deletions(-) diff --git a/src/SAWScript/Yosys/CompositionalTranslation.hs b/src/SAWScript/Yosys/CompositionalTranslation.hs index cde49303e7..d50126ea92 100644 --- a/src/SAWScript/Yosys/CompositionalTranslation.hs +++ b/src/SAWScript/Yosys/CompositionalTranslation.hs @@ -17,13 +17,13 @@ module SAWScript.Yosys.CompositionalTranslation ) where import Control.Lens.TH (makeLenses) + import Control.Lens ((^.)) import Control.Monad (forM, (>=>), void) import Control.Monad.IO.Class (MonadIO(..)) import Control.Exception (throw) import Data.Bifunctor (bimap) -import qualified Data.Maybe as Maybe import Data.Text (Text) import qualified Data.Text as Text import Data.Map (Map) @@ -45,28 +45,32 @@ type CellName = Text type Pattern = [Bitrep] type PatternMap m = Map Pattern ((YosysBitvecConsumer -> Pattern -> m SC.Term) -> m SC.Term) +-- | Information about the state type of a particular cell data CellStateInfo = CellStateInfo - { _cellStateInfoType :: SC.Term -- cell state type - either a bitvector for a $dff, or a record type - , _cellStateInfoCryptolType :: C.Type -- cryptol type for the above - , _cellStateInfoFields :: Maybe (Map Text (SC.Term, C.Type)) -- if the type is a record, the fields of the record + { _cellStateInfoType :: SC.Term -- ^ Cell state type - either a bitvector for a $dff, or a record type + , _cellStateInfoCryptolType :: C.Type -- ^ Cryptol type for the above + , _cellStateInfoFields :: Maybe (Map Text (SC.Term, C.Type)) -- ^ If the type is a record, the fields of the record } makeLenses ''CellStateInfo +-- | The SAWCore representation and SAW/Cryptol type information of a hardware module data TranslatedModule = TranslatedModule - { _translatedModuleStateInfo :: Maybe CellStateInfo -- information about the state type for this module - , _translatedModuleTerm :: SC.Term -- the lambda term for the output record (including state) in terms of the inputs (including state) - , _translatedModuleType :: SC.Term - , _translatedModuleCryptolType :: C.Type + { _translatedModuleStateInfo :: Maybe CellStateInfo -- ^ Information about the state type for this module + , _translatedModuleTerm :: SC.Term -- ^ The lambda term for the output record (including state) in terms of the inputs (including state) + , _translatedModuleType :: SC.Term -- ^ The SAWCore type of that term + , _translatedModuleCryptolType :: C.Type -- ^ The Cryptol type of that term } makeLenses ''TranslatedModule +-- | Information needed when translating a module data TranslationContext m = TranslationContext - { _translationContextModules :: Map ModuleName TranslatedModule - , _translationContextStateTypes :: Map CellName CellStateInfo -- state type for every stateful cell in this module (including sequential submodules) - , _translationContextPatternMap :: PatternMap m -- for each pattern, a term representing that pattern (parameterized by a function to get a term representing any other pattern) + { _translationContextModules :: Map ModuleName TranslatedModule -- ^ Context of previously translated modules + , _translationContextStateTypes :: Map CellName CellStateInfo -- ^ State information for every stateful cell in this module (including sequential submodules) + , _translationContextPatternMap :: PatternMap m -- ^ For each pattern, a term representing that pattern (parameterized by a function to get a term representing any other pattern) } makeLenses ''TranslationContext +-- | Given a module and the context of previously-translated modules, construct a mapping from cell names to state information buildTranslationContextStateTypes :: MonadIO m => SC.SharedContext -> @@ -85,13 +89,14 @@ buildTranslationContextStateTypes sc mods m = do pure $ Just CellStateInfo{..} _ -> pure Nothing +-- | Fetch the actual state term for a cell name, given the term for the __state__ input and information about what stateful cells exist lookupStateFor :: forall m. MonadIO m => SC.SharedContext -> - Map CellName CellStateInfo -> -- state type info for each cell - SC.Term -> -- record term mapping (zenc-ed) cell names to cell states - CellName -> -- cell state to lookup + Map CellName CellStateInfo {- ^ State type info for each cell -} -> + SC.Term {- ^ Record term mapping (zenc-ed) cell names to cell states -} -> + CellName {- ^ Cell state to lookup -} -> m SC.Term lookupStateFor sc states inpst cnm = do let fieldnm = cellIdentifier cnm @@ -101,7 +106,7 @@ lookupStateFor sc states inpst cnm = do insertStateField :: MonadIO m => SC.SharedContext -> - Map Text (SC.Term, C.Type) {- ^ The field types of \"__states__\" -} -> + Map Text (SC.Term, C.Type) {- ^ The field types of __states__ -} -> Map Text (SC.Term, C.Type) {- ^ The mapping to update -} -> m (Map Text (SC.Term, C.Type)) insertStateField sc stateFields fields = do @@ -109,59 +114,29 @@ insertStateField sc stateFields fields = do stateRecordCryptolType <- fieldsToCryptolType stateFields pure $ Map.insert "__state__" (stateRecordType, stateRecordCryptolType) fields -moduleInputPorts :: Module -> Map Text [Bitrep] -moduleInputPorts m = - Map.fromList - . Maybe.mapMaybe - ( \(nm, ip) -> - if ip ^. portDirection == DirectionInput || ip ^. portDirection == DirectionInout - then Just (nm, ip ^. portBits) - else Nothing - ) - . Map.assocs - $ m ^. modulePorts - -moduleOutputPorts :: Module -> Map Text [Bitrep] -moduleOutputPorts m = - Map.fromList - . Maybe.mapMaybe - ( \(nm, ip) -> - if ip ^. portDirection == DirectionOutput || ip ^. portDirection == DirectionInout - then Just (nm, ip ^. portBits) - else Nothing - ) - . Map.assocs - $ m ^. modulePorts - -cellInputConnections :: Cell [b] -> Map Text [b] -cellInputConnections c = Map.intersection (c ^. cellConnections) inp - where - inp = Map.filter (\d -> d == DirectionInput || d == DirectionInout) $ c ^. cellPortDirections - -cellOutputConnections :: Ord b => Cell [b] -> Map Text [b] -cellOutputConnections c = Map.intersection (c ^. cellConnections) out - where - out = Map.filter (\d -> d == DirectionOutput || d == DirectionInout) $ c ^. cellPortDirections - +-- | Construct a mapping from patterns to functions that construct terms for those patterns, given functions that construct terms for other patterns +-- We later "tie the knot" on this mapping given a few known patterns (e.g. module inputs and constants) to obtain actual terms for each pattern. buildPatternMap :: forall m. MonadIO m => SC.SharedContext -> - Map ModuleName TranslatedModule -> -- all previously-translated modules - Map CellName CellStateInfo -> -- state type info for each cell - SC.Term -> -- record term mapping inputs to terms (including a field __state__, a record mapping (zenc-ed) cell names to cell states) - Module -> -- the module being translated + Map ModuleName TranslatedModule {- ^ All previously-translated modules -} -> + Map CellName CellStateInfo {- ^ State type info for each cell -} -> + SC.Term {- ^ Record term mapping inputs to terms (including a field __state__, a record mapping (zenc-ed) cell names to cell states) -} -> + Module {- ^ The module being translated -} -> m (PatternMap m) buildPatternMap sc mods states inp m = do let inputPorts = moduleInputPorts m let inputFields = if Map.null states then void inputPorts else Map.insert "__state__" () $ void inputPorts + -- obtain a term for each input port by looking up their names in the input record inpTerms <- forM (Map.assocs inputPorts) $ \(nm, pat) -> do t <- liftIO $ cryptolRecordSelect sc inputFields inp nm fmap (const . pure) <$> deriveTermsByIndices sc pat t - -- grab the __state__ field from the module inputs + -- grab the __state__ field from the input record minpst <- if Map.null states then pure Nothing else Just <$> cryptolRecordSelect sc inputFields inp "__state__" + -- for each cell, construct a term for each output pattern, parameterized by a lookup function for other patterns ms <- forM (Map.toList $ m ^. moduleCells) $ \(cnm, c) -> do let inpPatterns = case c ^. cellType of @@ -199,6 +174,7 @@ buildPatternMap sc mods states inp m = do [ ("Q", cst) ] _ -> pure $ primCellToMap sc c + let -- given a pattern lookup function build a map from output patterns to terms f :: (YosysBitvecConsumer -> Pattern -> m SC.Term) -> m (Map Pattern SC.Term) @@ -216,6 +192,7 @@ buildPatternMap sc mods states inp m = do case Map.lookup pat pats of Nothing -> panic "buildPatternMap" ["Missing expected output pattern for cell"] Just t -> pure t + -- all of the pattern term functions for all of the cells in the module zeroTerm <- liftIO $ SC.scBvConst sc 1 0 oneTerm <- liftIO $ SC.scBvConst sc 1 1 @@ -236,17 +213,23 @@ buildPatternMap sc mods states inp m = do ] ] +-- | Given a translation context (consisting of the previously translated modules, state information, and pattern map), +-- lookup the term for a given pattern in the pattern map. translatePattern :: MonadIO m => SC.SharedContext -> TranslationContext m -> - YosysBitvecConsumer -> - Pattern -> + YosysBitvecConsumer {- ^ Source of this lookup (for error messages) -} -> + Pattern {- ^ Pattern to look up -} -> m SC.Term translatePattern sc ctx c p = do let pmap = ctx ^. translationContextPatternMap case Map.lookup p pmap of + -- if we find the pattern directly, use it (recursively calling translatePattern if other lookups are necessary) Just f -> f $ translatePattern sc ctx + -- otherwise, we look up each bit individually and concatenate to construct the term. + -- this is not an optimal scheme (e.g. you can imagine patterns [1, 2] and [3, 4] being present and looking up [1, 2, 3, 4]) + -- but it works well enough for now, and I suspect the resulting term size is easy to rewrite away in most cases Nothing -> do one <- liftIO $ SC.scNat sc 1 boolty <- liftIO $ SC.scBoolType sc @@ -259,17 +242,18 @@ translatePattern sc ctx c p = do vecBits <- liftIO $ SC.scVector sc onety bits liftIO $ SC.scJoin sc many one boolty vecBits +-- ^ Given previously translated modules, translate a module. +-- (This is the exported interface to the functionality implemented here.) translateModule :: MonadIO m => SC.SharedContext -> - Map ModuleName TranslatedModule -> - Module -> + Map ModuleName TranslatedModule {- ^ Context of previously-translated modules -} -> + Module {- ^ Yosys module to translate -} -> m TranslatedModule translateModule sc mods m = do + -- gather information about the stateful cells of the module states <- buildTranslationContextStateTypes sc mods m let stateFields = Map.fromList $ bimap cellIdentifier (\cs -> (cs ^. cellStateInfoType, cs ^. cellStateInfoCryptolType)) <$> Map.toList states - - -- description of the state fields of the module _translatedModuleStateInfo <- if Map.null states then pure Nothing else do @@ -281,6 +265,7 @@ translateModule sc mods m = do , _cellStateInfoFields = Just stateFields } + -- construct the module function's domain type (a record of all inputs, and optionally state) let inputPorts = moduleInputPorts m inputFields <- forM inputPorts $ \inp -> do ty <- liftIO . SC.scBitvector sc . fromIntegral $ length inp @@ -291,10 +276,12 @@ translateModule sc mods m = do else insertStateField sc stateFields inputFields domainRecordType <- fieldsToType sc domainFields domainRecordCryptolType <- fieldsToCryptolType domainFields + + -- construct a fresh variable of that type (this will become the parameter to the module function) domainRecordEC <- liftIO $ SC.scFreshEC sc "input" domainRecordType domainRecord <- liftIO $ SC.scExtCns sc domainRecordEC - minpst <- if Map.null states then pure Nothing else Just <$> cryptolRecordSelect sc domainFields domainRecord "__state__" + -- construct a pattern map from that domain record pmap <- buildPatternMap sc mods states domainRecord m let ctx = TranslationContext { _translationContextModules = mods @@ -302,6 +289,9 @@ translateModule sc mods m = do , _translationContextPatternMap = pmap } + -- if this module is stateful, grab the __state__ field from the domain record + minpst <- if Map.null states then pure Nothing else Just <$> cryptolRecordSelect sc domainFields domainRecord "__state__" + -- for each stateful cell, build a term representing the new state for that cell outstMap <- fmap Map.fromList . forM (Map.toList states) $ \(cnm, _cs) -> do case Map.lookup cnm (m ^. moduleCells) of @@ -314,9 +304,11 @@ translateModule sc mods m = do | Just subm <- Map.lookup (c ^. cellType) (ctx ^. translationContextModules) -> do -- otherwise, the cell is a stateful submodule: the new state is obtained from the submodules's update function applied to the inputs and old state let inpPatterns = cellInputConnections c + -- lookup the term for each input to the cell inps <- fmap Map.fromList . forM (Map.toList inpPatterns) $ \(inm, pat) -> (inm,) <$> translatePattern sc ctx (YosysBitvecConsumerCell cnm inm) pat let outPatterns = cellOutputConnections c + -- build a record containing all of the cell's inputs, and (if stateful) the appropriate state field sdomainFields <- case minpst of Nothing -> pure inps Just inpst -> do @@ -324,16 +316,25 @@ translateModule sc mods m = do pure $ Map.insert "__state__" subinpst inps let scodomainFields = if Map.null states then void outPatterns else Map.insert "__state__" () $ void outPatterns sdomainRec <- cryptolRecord sc sdomainFields + -- apply the cell's function to the domain record scodomainRec <- liftIO $ SC.scApply sc (subm ^. translatedModuleTerm) sdomainRec + -- grab the state field from the codomain record (cellIdentifier cnm,) <$> cryptolRecordSelect sc scodomainFields scodomainRec "__state__" _ -> panic "translateModule" ["Malformed stateful cell type"] + + -- build a record for the new value of __state__ outst <- cryptolRecord sc outstMap -- for each module output, collect a term for the output let outputPorts = moduleOutputPorts m outs <- fmap Map.fromList . forM (Map.toList outputPorts) $ \(onm, pat) -> (onm,) <$> translatePattern sc ctx (YosysBitvecConsumerOutputPort onm) pat + + -- construct the return value of the module codomainRecord <- cryptolRecord sc $ if Map.null states then outs else Map.insert "__state__" outst outs + + -- construct the module function's codomain type (a record of all outputs, and optionally state) + -- (this is the type of codomainRecord) outputFields <- forM outputPorts $ \inp -> do ty <- liftIO . SC.scBitvector sc . fromIntegral $ length inp let cty = C.tWord . C.tNum $ length inp @@ -342,8 +343,11 @@ translateModule sc mods m = do codomainRecordType <- fieldsToType sc codomainFields codomainRecordCryptolType <- fieldsToCryptolType codomainFields + -- abstract over the return value - this binds the free variable domainRecord with a lambda _translatedModuleTerm <- liftIO $ SC.scAbstractExts sc [domainRecordEC] codomainRecord + -- the type of _translatedModuleTerm - a function from domainRecordType to codomainRecordType _translatedModuleType <- liftIO $ SC.scFun sc domainRecordType codomainRecordType + -- the same type as a Cryptol type let _translatedModuleCryptolType = C.tFun domainRecordCryptolType codomainRecordCryptolType pure TranslatedModule{..} diff --git a/src/SAWScript/Yosys/IR.hs b/src/SAWScript/Yosys/IR.hs index 4e139c3ee4..7993b48649 100644 --- a/src/SAWScript/Yosys/IR.hs +++ b/src/SAWScript/Yosys/IR.hs @@ -18,10 +18,13 @@ module SAWScript.Yosys.IR where import Control.Lens.TH (makeLenses) +import Control.Lens ((^.)) import Control.Monad.IO.Class (MonadIO(..)) import Control.Exception (throw) +import qualified Data.Maybe as Maybe import Data.Map (Map) +import qualified Data.Map as Map import Data.Text (Text) import qualified Data.Text as Text @@ -133,3 +136,41 @@ loadYosysIR :: MonadIO m => FilePath -> m YosysIR loadYosysIR p = liftIO $ Aeson.eitherDecodeFileStrict p >>= \case Left err -> throw . YosysError $ Text.pack err Right ir -> pure ir + +-- | Return the patterns for all of the input ports of a module +moduleInputPorts :: Module -> Map Text [Bitrep] +moduleInputPorts m = + Map.fromList + . Maybe.mapMaybe + ( \(nm, ip) -> + if ip ^. portDirection == DirectionInput || ip ^. portDirection == DirectionInout + then Just (nm, ip ^. portBits) + else Nothing + ) + . Map.assocs + $ m ^. modulePorts + +-- | Return the patterns for all of the output ports of a module +moduleOutputPorts :: Module -> Map Text [Bitrep] +moduleOutputPorts m = + Map.fromList + . Maybe.mapMaybe + ( \(nm, ip) -> + if ip ^. portDirection == DirectionOutput || ip ^. portDirection == DirectionInout + then Just (nm, ip ^. portBits) + else Nothing + ) + . Map.assocs + $ m ^. modulePorts + +-- | Return the patterns for all of the input connections of a cell +cellInputConnections :: Cell [b] -> Map Text [b] +cellInputConnections c = Map.intersection (c ^. cellConnections) inp + where + inp = Map.filter (\d -> d == DirectionInput || d == DirectionInout) $ c ^. cellPortDirections + +-- | Return the patterns for all of the output connections of a cell +cellOutputConnections :: Ord b => Cell [b] -> Map Text [b] +cellOutputConnections c = Map.intersection (c ^. cellConnections) out + where + out = Map.filter (\d -> d == DirectionOutput || d == DirectionInout) $ c ^. cellPortDirections diff --git a/src/SAWScript/Yosys/Netgraph.hs b/src/SAWScript/Yosys/Netgraph.hs index 5ab66bd3a6..4b9753bc80 100644 --- a/src/SAWScript/Yosys/Netgraph.hs +++ b/src/SAWScript/Yosys/Netgraph.hs @@ -24,7 +24,6 @@ import Control.Monad (forM, foldM) import Control.Monad.IO.Class (MonadIO(..)) import Control.Exception (throw) -import qualified Data.Tuple as Tuple import qualified Data.Maybe as Maybe import qualified Data.List as List import Data.Map (Map) @@ -44,45 +43,11 @@ import SAWScript.Yosys.Utils import SAWScript.Yosys.IR import SAWScript.Yosys.Cell -moduleInputPorts :: Module -> Map Text [Bitrep] -moduleInputPorts m = - Map.fromList - . Maybe.mapMaybe - ( \(nm, ip) -> - if ip ^. portDirection == DirectionInput || ip ^. portDirection == DirectionInout - then Just (nm, ip ^. portBits) - else Nothing - ) - . Map.assocs - $ m ^. modulePorts - -moduleOutputPorts :: Module -> Map Text [Bitrep] -moduleOutputPorts m = - Map.fromList - . Maybe.mapMaybe - ( \(nm, ip) -> - if ip ^. portDirection == DirectionOutput || ip ^. portDirection == DirectionInout - then Just (nm, ip ^. portBits) - else Nothing - ) - . Map.assocs - $ m ^. modulePorts - -cellInputConnections :: Cell [b] -> Map Text [b] -cellInputConnections c = Map.intersection (c ^. cellConnections) inp - where - inp = Map.filter (\d -> d == DirectionInput || d == DirectionInout) $ c ^. cellPortDirections - -cellOutputConnections :: Ord b => Cell [b] -> Map [b] Text -cellOutputConnections c = Map.fromList . fmap Tuple.swap . Map.toList $ Map.intersection (c ^. cellConnections) out - where - out = Map.filter (\d -> d == DirectionOutput || d == DirectionInout) $ c ^. cellPortDirections - cellToEdges :: (Ord b, Eq b) => Cell [b] -> [(b, [b])] cellToEdges c = (, inputBits) <$> outputBits where inputBits = List.nub . mconcat . Map.elems $ cellInputConnections c - outputBits = List.nub . mconcat . Map.keys $ cellOutputConnections c + outputBits = List.nub . mconcat . Map.elems $ cellOutputConnections c -------------------------------------------------------------------------------- -- ** Building a network graph from a Yosys module @@ -219,7 +184,7 @@ netgraphToTerms sc env ng inputs Nothing -> throw $ YosysErrorNoSuchCellType (c ^. cellType) cnm -- once we've built a term, insert it along with each of its bits - ts <- forM (Map.assocs $ cellOutputConnections c) $ \(out, o) -> do + ts <- forM (Map.assocs $ cellOutputConnections c) $ \(o, out) -> do t <- cryptolRecordSelect sc outputFields r o deriveTermsByIndices sc out t pure $ Map.union (Map.unions ts) acc