Skip to content

Commit

Permalink
remove ensurePi
Browse files Browse the repository at this point in the history
  • Loading branch information
sweirich committed Jun 30, 2023
1 parent 77679d1 commit 0aa614e
Show file tree
Hide file tree
Showing 18 changed files with 298 additions and 409 deletions.
16 changes: 2 additions & 14 deletions doc/oplss.mng
Original file line number Diff line number Diff line change
Expand Up @@ -1398,19 +1398,6 @@ equate :: Term -> Term -> TcMonad ()
This function ensures that the two provided types are equal, or throws a type error if
they are not.

Additionally, the \pif includes the function
% some ``head'' form, without knowing exactly what that form is. For example,
% when \emph{checking} lambda expressions, we need to know that the provided
% type is of the form of a $\Pi$-type ($[[(x:A)-> B]]$). Likewise, when inferring
% the type of an application, we need to know that the type inferred for the
% function is actually a $\Pi$-type.
\begin{verbatim}
ensurePi :: Type -> TcMonad (TName, Type, Type)
\end{verbatim}
that checks the given type to see if the given is equal to some $\Pi$-type of the form
$[[(x:A1)-> A2]]$, and if so returns \texttt{x},
\texttt{A1} and \texttt{A2}.

This function is defined in terms of a helper function that implements the
rules shown in Figure~\ref{fig:whnf}.
\begin{verbatim}
Expand All @@ -1423,7 +1410,8 @@ these functions are called in a few places:
\begin{itemize}
\item \texttt{equate} is called at the end of \texttt{checkType} to make sure
that the annotated type matches the inferred type.
\item \texttt{ensurePi} is called in the \texttt{App} case
\item \texttt{whnf} is called in the \texttt{App} case to ensure that
the function has some sort of function type.
of \texttt{inferType}
\item \texttt{whnf} is called at the beginning of \texttt{checkType}
to make sure that we are using the head form of the type in checking
Expand Down
Binary file modified doc/oplss.pdf
Binary file not shown.
48 changes: 10 additions & 38 deletions full/src/Equal.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
{- pi-forall language -}

-- | Compare two terms for equality
module Equal (whnf, equate, ensurePi, unify
{- SOLN EQUAL -},ensureTyEq {- STUBWITH -}
module Equal (whnf, equate, unify
{- SOLN DATA -},ensureTCon{- STUBWITH -} ) where

import Syntax
Expand All @@ -25,14 +24,14 @@ equate t1 t2 = do
(TyType, TyType) -> return ()
(Var x, Var y) | x == y -> return ()
(Lam ep1 bnd1, Lam ep2 bnd2) -> do
(_, b1, _, b2) <- Unbound.unbind2Plus bnd1 bnd2
(_, b1, b2) <- unbind2 bnd1 bnd2
unless (ep1 == ep2) $
tyErr n1 n2
equate b1 b2
(App a1 a2, App b1 b2) ->
equate a1 b1 >> equateArg a2 b2
(TyPi ep1 tyA1 bnd1, TyPi ep2 tyA2 bnd2) -> do
(_, tyB1, _, tyB2) <- Unbound.unbind2Plus bnd1 bnd2
(_, tyB1, tyB2) <- unbind2 bnd1 bnd2
unless (ep1 == ep2) $
tyErr n1 n2
equate tyA1 tyA2
Expand All @@ -54,12 +53,12 @@ equate t1 t2 = do
equate c1 c2

(Let rhs1 bnd1, Let rhs2 bnd2) -> do
Just (x, body1, _, body2) <- Unbound.unbind2 bnd1 bnd2
(x, body1, body2) <- unbind2 bnd1 bnd2
equate rhs1 rhs2
equate body1 body2

(TySigma tyA1 bnd1, TySigma tyA2 bnd2) -> do
Just (x, tyB1, _, tyB2) <- Unbound.unbind2 bnd1 bnd2
(x, tyB1, tyB2) <- unbind2 bnd1 bnd2
equate tyA1 tyA2
equate tyB1 tyB2

Expand All @@ -69,7 +68,7 @@ equate t1 t2 = do

(LetPair s1 bnd1, LetPair s2 bnd2) -> do
equate s1 s2
Just ((x,y), body1, _, body2) <- Unbound.unbind2 bnd1 bnd2
((x,y), body1, _, body2) <- Unbound.unbind2Plus bnd1 bnd2
equate body1 body2
(TyEq a b, TyEq c d) -> do
equate a c
Expand Down Expand Up @@ -133,32 +132,6 @@ equateArg a1 a2 =


-------------------------------------------------------

-- | Ensure that the given type 'ty' is a 'TyPi' type
-- (or could be normalized to be such) and return the components of
-- the type.
-- Throws an error if this is not the case.
ensurePi :: Type ->
TcMonad (Epsilon, Type, (Unbound.Bind TName Type))
ensurePi ty = do
nf <- whnf ty
case nf of
(TyPi ep tyA bnd) -> do
return (ep, tyA, bnd)
_ -> Env.err [DS "Expected a function type, instead found", DD nf]


-- | Ensure that the given 'ty' is an equality type
-- (or could be normalized to be such) and return
-- the LHS and RHS of that equality
-- Throws an error if this is not the case.
ensureTyEq :: Term -> TcMonad (Term,Term)
ensureTyEq ty = do
nf <- whnf ty
case nf of
TyEq m n -> return (m, n)
_ -> Env.err [DS "Expected an equality type, instead found", DD nf]


-- | Ensure that the given type 'ty' is some tycon applied to
-- params (or could be normalized to be such)
Expand All @@ -185,7 +158,7 @@ whnf (App t1 t2) = do
nf <- whnf t1
case nf of
(Lam ep bnd) -> do
whnf (Unbound.instantiate bnd [unArg t2] )
whnf (instantiate bnd (unArg t2) )
_ -> do
return (App nf t2)

Expand All @@ -207,8 +180,7 @@ whnf (Ann tm _) = whnf tm
whnf (Pos _ tm) = whnf tm

whnf (Let rhs bnd) = do
-- (x,body) <- Unbound.unbind bnd
whnf (Unbound.instantiate bnd [rhs])
whnf (instantiate bnd rhs)
whnf (Subst tm pf) = do
pf' <- whnf pf
case pf' of
Expand Down Expand Up @@ -254,12 +226,12 @@ unify ns tx ty = do
(DataCon s1 a1s, DataCon s2 a2s)
| s1 == s2 -> unifyArgs a1s a2s
(Lam ep1 bnd1, Lam ep2 bnd2) -> do
(x, b1, _, b2) <- Unbound.unbind2Plus bnd1 bnd2
(x, b1, b2) <- unbind2 bnd1 bnd2
unless (ep1 == ep2) $ do
Env.err [DS "Cannot equate", DD txnf, DS "and", DD tynf]
unify (x:ns) b1 b2
(TyPi ep1 tyA1 bnd1, TyPi ep2 tyA2 bnd2) -> do
(x, tyB1, _, tyB2) <- Unbound.unbind2Plus bnd1 bnd2
(x, tyB1, tyB2) <- unbind2 bnd1 bnd2
unless (ep1 == ep2) $ do
Env.err [DS "Cannot equate", DD txnf, DS "and", DD tynf]
ds1 <- unify ns tyA1 tyA2
Expand Down
46 changes: 26 additions & 20 deletions full/src/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Data.Set (Set)
import Data.Set qualified as Set
import Data.Typeable (Typeable)
import GHC.Generics (Generic,from)
import GHC.Base (MonadPlus)
import Text.ParserCombinators.Parsec.Pos (SourcePos, initialPos, newPos)
import Unbound.Generics.LocallyNameless qualified as Unbound
import Unbound.Generics.LocallyNameless.Internal.Fold qualified as Unbound
Expand Down Expand Up @@ -311,6 +312,9 @@ unPosFlaky t = fromMaybe (newPos "unknown location" 0 0) (unPos t)
-- functions for alpha-equivalence, free variables and substitution
-- using generic programming.

-- The definitions below specialize the generic operations from the libary
-- to some of the common uses that we need in pi-forall

-- | Determine when two terms are alpha-equivalent (see below)
aeq :: Term -> Term -> Bool
aeq = Unbound.aeq
Expand All @@ -319,32 +323,35 @@ aeq = Unbound.aeq
fv :: Term -> [Unbound.Name Term]
fv = Unbound.toListOf Unbound.fv

-- | subst x b a means to replace x with b in a
-- | `subst x b a` means to replace `x` with `b` in `a`
-- i.e. a [ b / x ]
subst :: TName -> Term -> Term -> Term
subst = Unbound.subst

-- | in a binder "x.a" replace x with b
-- | in a binder `x.a` replace `x` with `b`
instantiate :: Unbound.Bind TName Term -> Term -> Term
instantiate bnd a = Unbound.instantiate bnd [a]

-- | in a binder "x.a" replace x with a fresh name
-- | in a binder `x.a` replace `x` with a fresh name
unbind :: (Unbound.Fresh m) => Unbound.Bind TName Term -> m (TName, Term)
unbind = Unbound.unbind

-- | in binders `x.a1` and `x.a2` replace `x` with a fresh name in both terms
unbind2 :: (Unbound.Fresh m) => Unbound.Bind TName Term -> Unbound.Bind TName Term -> m (TName, Term, Term)
unbind2 b1 b2 = do
o <- Unbound.unbind2 b1 b2
case o of
Just (x,t,_,u) -> return (x,t,u)
Nothing -> error "impossible"
------------------

-- * Alpha equivalence and free variables
-- * `Alpha` class instances

-- The Unbound library's Alpha class enables the following
-- functions:
-- -- Compare terms for alpha equivalence
-- aeq :: Alpha a => a -> a -> Bool
-- -- Calculate the free variables of a term
-- fv :: Alpha a => a -> [Unbound.Name a]
-- -- Destruct a binding, generating fresh names for the bound variables
-- unbind :: (Alpha p, Alpha t, Fresh m) => Bind p t -> m (p, t)
-- The Unbound library's `Alpha` class enables the `aeq`, `fv`,
-- `instantiate` and `unbind` functions, and also allows some
-- specialization of their generic behavior.

-- For Terms, we'd like Alpha equivalence to ignore
-- For `Term`, we'd like Alpha equivalence to ignore
-- source positions and type annotations. So we make sure to
-- remove them before calling the generic operation.

Expand Down Expand Up @@ -390,13 +397,12 @@ idy = Lam Rel (Unbound.bind yName (Var yName))

-- * Substitution

-- The Subst class derives capture-avoiding substitution
-- It has two parameters because the sort of thing we are substituting
-- The Subst class derives capture-avoiding substitution.
-- It has two parameters because the type of thing we are substituting
-- for may not be the same as what we are substituting into.

-- class Subst b a where
-- subst :: Name b -> b -> a -> a -- single substitution

-- The `isvar` function identifies the variables in the term that
-- should be substituted for.
instance Unbound.Subst Term Term where
isvar (Var x) = Just (Unbound.SubstName x)
isvar _ = Nothing
Expand All @@ -417,8 +423,8 @@ pi2 = TyPi Rel TyBool (Unbound.bind yName (Var yName))
-- * Source Positions

-- SourcePositions do not have an instance of the Generic class available
-- so we cannot automatically define their Alpha and Subst instances. Instead
-- we do so by hand here.
-- so we cannot automatically define their `Alpha` and `Subst` instances.
-- Instead we provide a trivial implementation here.
instance Unbound.Alpha SourcePos where
aeq' _ _ _ = True
fvAny' _ _ = pure
Expand Down
50 changes: 27 additions & 23 deletions full/src/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import Debug.Trace
import Text.PrettyPrint.HughesPJ (($$), render)

import Unbound.Generics.LocallyNameless qualified as Unbound
import Unbound.Generics.LocallyNameless.Internal.Fold qualified as Unbound
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)


Expand All @@ -41,25 +40,24 @@ inferType a = case a of

-- i-pi
(TyPi ep tyA bnd) -> do
(x, tyB) <- Unbound.unbind bnd
(x, tyB) <- unbind bnd
tcType tyA
Env.extendCtx (Decl (TypeDecl x ep tyA)) (tcType tyB)
return TyType

-- i-app
(App a b) -> do
ty1 <- inferType a
let ensurePi = Equal.ensurePi

(ep1, tyA, bnd) <- ensurePi ty1
unless (ep1 == argEp b) $ Env.err
[DS "In application, expected", DD ep1, DS "argument but found",
DD b, DS "instead." ]
-- if the argument is Irrelevant, resurrect the context
(if ep1 == Irr then Env.extendCtx (Demote Rel) else id) $
checkType (unArg b) tyA
return (Unbound.instantiate bnd [unArg b])

ty1' <- Equal.whnf ty1
case ty1' of
(TyPi {- SOLN EP -}ep1 {- STUBWITH -} tyA bnd) -> do
unless (ep1 == argEp b) $ Env.err
[DS "In application, expected", DD ep1, DS "argument but found",
DD b, DS "instead." ]
-- if the argument is Irrelevant, resurrect the context
(if ep1 == Irr then Env.extendCtx (Demote Rel) else id) $ checkType (unArg b) tyA
return (instantiate bnd (unArg b) )
_ -> Env.err [DS "Expected a function type but found ", DD ty1]

-- i-ann
(Ann a tyA) -> do
Expand Down Expand Up @@ -92,7 +90,7 @@ inferType a = case a of

-- i-sigma
(TySigma tyA bnd) -> do
(x, tyB) <- Unbound.unbind bnd
(x, tyB) <- unbind bnd
tcType tyA
Env.extendCtx (mkDecl x tyA) $ tcType tyB
return TyType
Expand Down Expand Up @@ -166,7 +164,7 @@ checkType tm ty = do
(Lam ep1 bnd) -> case ty' of
(TyPi ep2 tyA bnd2) -> do
-- unbind the variables in the lambda expression and pi type
(x, body, _, tyB) <- Unbound.unbind2Plus bnd bnd2
(x, body, tyB) <- unbind2 bnd bnd2
-- epsilons should match up
unless (ep1 == ep2) $ Env.err [DS "In function definition, expected", DD ep2, DS "parameter", DD x,
DS "but found", DD ep1, DS "instead."]
Expand Down Expand Up @@ -197,7 +195,7 @@ checkType tm ty = do
(Prod a b) -> do
case ty' of
(TySigma tyA bnd) -> do
(x, tyB) <- Unbound.unbind bnd
(x, tyB) <- unbind bnd
checkType a tyA
Env.extendCtxs [mkDecl x tyA, Def x a] $ checkType b tyB
_ ->
Expand All @@ -214,15 +212,15 @@ checkType tm ty = do
pty' <- Equal.whnf pty
case pty' of
TySigma tyA bnd' -> do
let tyB = Unbound.instantiate bnd' [Var x]
let tyB = instantiate bnd' (Var x)
decl <- Equal.unify [] p (Prod (Var x) (Var y))
Env.extendCtxs ([mkDecl x tyA, mkDecl y tyB] ++ decl) $
checkType body tyA
_ -> Env.err [DS "Scrutinee of LetPair must have Sigma type"]

-- c-let
(Let a bnd) -> do
(x, b) <- Unbound.unbind bnd
(x, b) <- unbind bnd
tyA <- inferType a
Env.extendCtxs [mkDecl x tyA, Def x a] $
checkType b ty'
Expand All @@ -235,7 +233,10 @@ checkType tm ty = do
-- infer the type of the proof 'b'
tp <- inferType b
-- make sure that it is an equality between m and n
(m, n) <- Equal.ensureTyEq tp
nf <- Equal.whnf tp
(m, n) <- case nf of
TyEq m n -> return (m,n)
_ -> Env.err [DS "Subst requires an equality type, not", DD tp]
-- if either side is a variable, add a definition to the context
edecl <- Equal.unify [] m n
-- if proof is a variable, add a definition to the context
Expand All @@ -244,7 +245,10 @@ checkType tm ty = do
-- c-contra
(Contra p) -> do
ty' <- inferType p
(a, b) <- Equal.ensureTyEq ty'
nf <- Equal.whnf ty'
(a, b) <- case nf of
TyEq m n -> return (m,n)
_ -> Env.err [DS "Contra requires an equality type, not", DD ty']
a' <- Equal.whnf a
b' <- Equal.whnf b
case (a', b') of
Expand Down Expand Up @@ -293,8 +297,8 @@ checkType tm ty = do

(Case scrut alts) -> do
sty <- inferType scrut
scrut' <- Equal.whnf scrut
(c, args) <- Equal.ensureTCon sty
scrut' <- Equal.whnf scrut
let checkAlt (Match bnd) = do
(pat, body) <- Unbound.unbind bnd
-- add variables from pattern to context
Expand All @@ -313,8 +317,8 @@ checkType tm ty = do


-- c-infer
a -> do
tyA <- inferType a
_ -> do
tyA <- inferType tm
Equal.equate tyA ty'


Expand Down
Loading

0 comments on commit 0aa614e

Please sign in to comment.