Skip to content

Commit

Permalink
flattener cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorTaelin committed Oct 19, 2024
1 parent 1ba8b3c commit 90405be
Showing 1 changed file with 72 additions and 187 deletions.
259 changes: 72 additions & 187 deletions src/Kind/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,8 @@ parseDef = do
body <- parseTerm
return (pats, body)
let (mat, bods) = unzip rules
let flat = clean (flattenDef mat bods 0) 0
return $ trace ("DONE: " ++ termShow flat) flat
let flat = flattenDef mat bods 0
return $ {-trace ("DONE: " ++ termShow flat)-} flat
]
(_, uses) <- P.getState
let name' = expandUses uses name
Expand All @@ -514,191 +514,6 @@ parsePattern = do
return (PCtr name args)
]

-- Flattener
-- ---------

colName :: [Pattern] -> Maybe String
colName col = foldr (A.<|>) Nothing $ map (\case PVar nam -> Just nam; _ -> Nothing) col

isVar :: Pattern -> Bool
isVar (PVar _) = True
isVar _ = False

countSubPatterns :: Pattern -> Int
countSubPatterns (PCtr _ pats) = length pats
countSubPatterns _ = 0

extractConstructors :: [Pattern] -> [String]
extractConstructors = foldr (\pat acc -> case pat of (PCtr nam _) -> nam:acc ; _ -> acc) []

-- Flattener for pattern matching equations
flattenDef :: [[Pattern]] -> [Term] -> Int -> Term
flattenDef (pats:mat) (bod:bods) fresh =
trace ("Flattening definition with " ++ show (length (pats:mat)) ++ " rows") $
if null pats
then trace "No patterns left, returning body" bod
else
let bods' = bod:bods
(col, mat') = unzip (catMaybes (map uncons (pats:mat)))
in if all isVar col
then trace "All patterns are variables, flattening variables" $ flattenVar col mat' bods' fresh
else trace "ADT patterns found, flattening ADT" $ flattenAdt col mat' bods' fresh
flattenDef _ _ fresh = trace "Error: No patterns or bodies left" $ Hol "flatten error" []

-- Handle variable patterns
flattenVar :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Term
flattenVar col mat bods fresh =
trace ("Flattening variable patterns: " ++ show col) $
let var' = "%x" ++ show fresh ++ "-LAM"
fresh' = fresh + 1
var = maybe var' id (colName col)
bods' = zipWith useVarInBody col bods
bod = flattenDef mat bods' fresh'
in Lam var (\x -> bod)

useVarInBody :: Pattern -> Term -> Term
useVarInBody (PVar nam) bod = Use nam (Ref nam) (\x -> bod)
useVarInBody _ bod = bod

-- Handle ADT patterns
flattenAdt :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Term
flattenAdt col mat bods fresh =
trace ("Flattening ADT patterns: " ++ show col ++ " name: " ++ maybe ("%x" ++ show fresh ++ "-ADT") id (colName col)) $
let var' = "%x" ++ show fresh ++ "-ADT"
fresh'= fresh + 1
var = maybe var' id (colName col)
ctrs' = toList (fromList (extractConstructors col))
nPats = maximum (map countSubPatterns col)
cse = map (processCtr col mat bods nPats fresh' var) ctrs'
dfl = processDefaultCase col mat bods var
cse' = if null (snd dfl) then cse else cse ++ [("_", flattenDef (fst dfl) (snd dfl) fresh')]
bod = App (Mat cse') (Ref var)
in Lam var (\x -> bod)

processCtr :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Int -> String -> String -> (String, Term)
processCtr col mat bods nPats fresh' var ctr =
trace ("Processing constructor: " ++ ctr) $
let (mat', bods') = foldr (processPattern ctr nPats var) ([], []) (zip3 col mat bods)
bod = flattenDef mat' bods' fresh'
in (ctr, bod)

processPattern :: String -> Int -> String -> (Pattern, [Pattern], Term) -> ([[Pattern]], [Term]) -> ([[Pattern]], [Term])
processPattern ctr nPats var (pat, pats, bod) (mat, bods) = case pat of
(PCtr nam newPats) ->
if nam == ctr
then ((newPats ++ pats):mat, bod:bods)
else (mat, bods)
(PVar nam) ->
let newPats = [PVar (nam ++ "." ++ show i) | i <- [0..nPats-1]]
bod' = Use nam (Ref var) (\x -> bod)
in ((newPats ++ pats):mat, bod':bods)

processDefaultCase :: [Pattern] -> [[Pattern]] -> [Term] -> String -> ([[Pattern]], [Term])
processDefaultCase col mat bods var =
trace "Processing default case" $
foldr processDefaultPattern ([], []) (zip3 col mat bods)

processDefaultPattern :: (Pattern, [Pattern], Term) -> ([[Pattern]], [Term]) -> ([[Pattern]], [Term])
processDefaultPattern (pat, pats, bod) (mat', bods') =
-- TODO: make the log above more expressive, by also logging all terms in bods'
-- write below the updated 'trace' statement
trace ("Processing default pattern: " ++ show pat ++ ", " ++ show pats ++ ", " ++ termShow bod ++ ", bods': " ++ show (map termShow bods')) $
case pat of
PVar nam ->
let bod' = Use nam (Ref nam) (\x -> bod)
in ((pat:pats):mat', bod':bods')
_ -> (mat', bods')

-- FIXME: refactor the flattener to avoid needing this
clean :: Term -> Int -> Term
clean term dep = {-trace ("clean " ++ termShower False term dep) $-} maybe (go term dep) id (fix term dep) where

fix (Lam nam bod) dep
| App (Mat cse) (Ref arg) <- bod (Var nam dep)
, nam == arg
= Just (clean (Mat cse) dep)
fix (Use nam (Ref val) bod) dep
| nam == val
= Just (clean (bod (Ref "??")) dep)
fix other dep
= Nothing

go (All nam typ bod) dep =
let typ' = clean typ dep
bod' = \x -> clean (bod x) (dep+1)
in All nam typ' bod'
go (Lam nam bod) dep =
let bod' = \x -> clean (bod x) (dep+1)
in Lam nam bod'
go (App fun arg) dep =
let fun' = clean fun dep
arg' = clean arg dep
in App fun' arg'
go (Ann chk val typ) dep =
let val' = clean val dep
typ' = clean typ dep
in Ann chk val' typ'
go (Slf nam typ bod) dep =
let typ' = clean typ dep
bod' = \x -> clean (bod x) (dep+1)
in Slf nam typ' bod'
go (Ins val) dep =
let val' = clean val dep
in Ins val'
go (Dat scp cts) dep =
let scp' = map (\t -> clean t dep) scp
cts' = map (\(Ctr n t) -> Ctr n (cleanTele t dep)) cts
in Dat scp' cts'
go (Con nam args) dep =
let args' = map (\(n, t) -> (n, clean t dep)) args
in Con nam args'
go (Mat cse) dep =
let cse' = map (\(n, t) -> (n, clean t dep)) cse
in Mat cse'
go (Use nam val bod) dep =
let val' = clean val dep
bod' = \x -> clean (bod x) (dep+1)
in Use nam val' bod'
go (Let nam val bod) dep =
let val' = clean val dep
bod' = \x -> clean (bod x) (dep+1)
in Let nam val' bod'
go (Op2 op a b) dep =
let a' = clean a dep
b' = clean b dep
in Op2 op a' b'
go (Swi zero succ) dep =
let zero' = clean zero dep
succ' = clean succ dep
in Swi zero' succ'
go (Hol nam ctx) dep =
let ctx' = map (\t -> clean t dep) ctx
in Hol nam ctx'
go (Met idx ctx) dep =
let ctx' = map (\t -> clean t dep) ctx
in Met idx ctx'
go (Src cod term) dep =
let term' = clean term dep
in Src cod term'
go (Ref name) dep = Ref name
go Set dep = Set
go U32 dep = U32
go (Num n) dep = Num n
go (Txt s) dep = Txt s
go (Lst ts) dep =
let ts' = map (\t -> clean t dep) ts
in Lst ts'
go (Nat n) dep = Nat n
go (Var nam idx) dep = Var nam idx


cleanTele :: Tele -> Int -> Tele
cleanTele (TRet term) dep = TRet (clean term dep)
cleanTele (TExt nam typ tele) dep =
let typ' = clean typ dep
tele' = \x -> cleanTele (tele x) (dep+1)
in TExt nam typ' tele'

parseUses :: Parser Uses
parseUses = P.many $ P.try $ do
string "use "
Expand Down Expand Up @@ -842,3 +657,73 @@ parseNat = withSrc $ P.try $ do
char '#'
num <- P.many1 digit
return $ Nat (read num)

-- Flattener
-- ---------

-- Flattener for pattern matching equations
flattenDef :: [[Pattern]] -> [Term] -> Int -> Term
flattenDef ([]:mat) (bod:bods) fresh = bod
flattenDef (pats:mat) (bod:bods) fresh
| all isVar col = flattenVarCol col mat' (bod:bods) fresh
| otherwise = flattenAdtCol col mat' (bod:bods) fresh
where (col,mat') = getCol (pats:mat)
flattenDef _ _ _ = error "internal error"

-- Flattens a column with only variables
flattenVarCol :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Term
flattenVarCol col mat bods fresh =
let nam = maybe ("%x" ++ show fresh) id (getColName col)
bod = flattenDef mat bods (fresh + 1)
in Lam nam (\x -> bod)

-- Flattens a column with constructors and possibly variables
flattenAdtCol :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Term
flattenAdtCol col mat bods fresh =
let nam = maybe ("%f" ++ show fresh) id (getColName col)
ctr = map (makeCtrCase col mat bods (getColArity col) (fresh+1) nam) (getColCtrs col)
dfl = makeDflCase col mat bods fresh
in Mat (ctr++dfl)

-- Creates a constructor case: '#Name: body'
makeCtrCase :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Int -> String -> String -> (String, Term)
makeCtrCase col mat bods arity fresh var ctr =
let (mat', bods') = foldr go ([], []) (zip3 col mat bods)
bod = flattenDef mat' bods' fresh
in (ctr, bod)
where go ((PCtr nam ps), pats, bod) (mat, bods)
| nam == ctr = ((ps ++ pats):mat, bod:bods)
| otherwise = (mat, bods)
go ((PVar nam), pats, bod) (mat, bods) =
let ps = [PVar (nam++"."++show i) | i <- [0..arity-1]]
in ((ps ++ pats) : mat, bod:bods)

-- Creates a default case: '#_: body'
makeDflCase :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> [(String, Term)]
makeDflCase col mat bods fresh =
let (mat', bods') = foldr go ([], []) (zip3 col mat bods) in
if null bods' then [] else [("_", flattenDef mat' bods' (fresh+1))]
where go ((PVar nam), pats, bod) (mat, bods) = (((PVar nam):pats):mat, bod:bods)
go (ctr, pats, bod) (mat, bods) = (mat, bods)

-- Helper Functions

isVar :: Pattern -> Bool
isVar (PVar _) = True
isVar _ = False

getArity :: Pattern -> Int
getArity (PCtr _ pats) = length pats
getArity _ = 0

getCol :: [[Pattern]] -> ([Pattern], [[Pattern]])
getCol (pats:mat) = unzip (catMaybes (map uncons (pats:mat)))

getColCtrs :: [Pattern] -> [String]
getColCtrs col = toList . fromList $ foldr (\pat acc -> case pat of (PCtr nam _) -> nam:acc ; _ -> acc) [] col

getColName :: [Pattern] -> Maybe String
getColName col = foldr (A.<|>) Nothing $ map (\case PVar nam -> Just nam; _ -> Nothing) col

getColArity :: [Pattern] -> Int
getColArity col = maximum (map getArity col)

0 comments on commit 90405be

Please sign in to comment.