Skip to content

Commit

Permalink
feat: use letFun function for let_fun instead of annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
kmill committed Nov 28, 2023
1 parent 190ac50 commit ca4f83c
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 63 deletions.
4 changes: 4 additions & 0 deletions src/Init/Prelude.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ example (b : Bool) : Function.const Bool 10 b = 10 :=
@[inline] def Function.const {α : Sort u} (β : Sort v) (a : α) : β → α :=
fun _ => a

/-- The encoding of `let_fun x := v; y` is `letFun v (fun x => y)`,
which is an abbreviation for `(fun x => y) v`. -/
def letFun {α : Sort u} {β : α → Sort v} (v : α) (f : (x : α) → β x) : β v := f v

set_option checkBinderAnnotations false in
/--
`inferInstance` synthesizes a value of any target type by typeclass
Expand Down
11 changes: 5 additions & 6 deletions src/Lean/Compiler/LCNF/ToLCNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,9 @@ where
visit (f.beta e.getAppArgs)

visitApp (e : Expr) : M Arg := do
if let .const declName _ := e.getAppFn then
if let some (n, t, v, b) := e.letFun? then
visitLet (.letE n t v b (nonDep := true)) #[]
else if let .const declName _ := e.getAppFn then
if declName == ``Quot.lift then
visitQuotLift e
else if declName == ``Quot.mk then
Expand Down Expand Up @@ -725,11 +727,8 @@ where
pushElement (.fun funDecl)
return .fvar funDecl.fvarId

visitMData (mdata : MData) (e : Expr) : M Arg := do
if let some (.app (.lam n t b ..) v) := letFunAnnotation? (.mdata mdata e) then
visitLet (.letE n t v b (nonDep := true)) #[]
else
visit e
visitMData (_mdata : MData) (e : Expr) : M Arg := do
visit e

visitProj (s : Name) (i : Nat) (e : Expr) : M Arg := do
match (← visit e) with
Expand Down
5 changes: 2 additions & 3 deletions src/Lean/Elab/Binders.lean
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,11 @@ def elabLetDeclAux (id : Syntax) (binders : Array Syntax) (typeStx : Syntax) (va
let body ← instantiateMVars body
mkLetFVars #[x] body (usedLetOnly := usedLetOnly)
else
let f ← withLocalDecl id.getId (kind := kind) .default type fun x => do
withLocalDecl id.getId (kind := kind) .default type fun x => do
addLocalVarInfo id x
let body ← elabTermEnsuringType body expectedType?
let body ← instantiateMVars body
mkLambdaFVars #[x] body (usedLetOnly := false)
pure <| mkLetFunAnnotation (mkApp f val)
mkLetFun x val body
if elabBodyFirst then
forallBoundedTelescope type binders.size fun xs type => do
-- the original `fvars` from above are gone, so add back info manually
Expand Down
8 changes: 6 additions & 2 deletions src/Lean/Elab/PreDefinition/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ structure EqnInfoCore where
partial def expand : Expr → Expr
| Expr.letE _ _ v b _ => expand (b.instantiate1 v)
| Expr.mdata _ b => expand b
| e => e
| e =>
if let some (_, _, v, b) := e.letFun? then
expand (b.instantiate1 v)
else
e

def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | return none
unless rhs.isLet || rhs.isMData do return none
unless rhs.isLet || rhs.isLetFun || rhs.isMData do return none
return some (← mvarId.replaceTargetDefEq (← mkEq lhs (expand rhs)))

def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do
Expand Down
39 changes: 17 additions & 22 deletions src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,23 @@ def setAppPPExplicitForExposingMVars (e : Expr) : Expr :=
mkAppN f args |>.setPPExplicit true
| _ => e

/--
Return true if `e` is a `let_fun` expression.
-/
def isLetFun (e : Expr) : Bool := e.isAppOfArity ``letFun 4

/-- Recognizes a `let_fun` expression. For `let_fun n : t := v; b`, returns `some (n, t, v, b)`,
which are the arguments to `Lean.Expr.letE`.
If in the encoding of `let_fun` the argument to `letFun` is eta reduced, uses `Name.anonymous` for the binder name. -/
def letFun? (e : Expr) : Option (Name × Expr × Expr × Expr) :=
match e with
| .app (.app (.app (.app (.const ``letFun _) α) _β) v) f =>
match f with
| .lam n _ b _ => some (n, α, v, b)
| _ => some (.anonymous, α, v, .app f (.bvar 0))
| _ => none

end Expr

/--
Expand All @@ -1670,28 +1687,6 @@ def annotation? (kind : Name) (e : Expr) : Option Expr :=
| .mdata d b => if d.size == 1 && d.getBool kind false then some b else none
| _ => none

/--
Annotate `e` with the `let_fun` annotation. This annotation is used as hint for the delaborator.
If `e` is of the form `(fun x : t => b) v`, then `mkLetFunAnnotation e` is delaborated at
`let_fun x : t := v; b`
-/
def mkLetFunAnnotation (e : Expr) : Expr :=
mkAnnotation `let_fun e

/--
Return `some e'` if `e = mkLetFunAnnotation e'`
-/
def letFunAnnotation? (e : Expr) : Option Expr :=
annotation? `let_fun e

/--
Return true if `e = mkLetFunAnnotation e'`, and `e'` is of the form `(fun x : t => b) v`
-/
def isLetFun (e : Expr) : Bool :=
match letFunAnnotation? e with
| none => false
| some e => e.isApp && e.appFn!.isLambda

/--
Auxiliary annotation used to mark terms marked with the "inaccessible" annotation `.(t)` and
`_` in patterns.
Expand Down
11 changes: 11 additions & 0 deletions src/Lean/Meta/AppBuilder.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ def mkExpectedTypeHint (e : Expr) (expectedType : Expr) : MetaM Expr := do
let u ← getLevel expectedType
return mkApp2 (mkConst ``id [u]) expectedType e

/-- `mkLetFun x v e` creates the encoding for the `let_fun x := v; e` expression.
The expression `x` can either be a free variable or a metavariable, and the function suitably abstracts `x` in `e`. -/
def mkLetFun (x : Expr) (v : Expr) (e : Expr) : MetaM Expr := do
let f ← mkLambdaFVars #[x] e
let ety ← inferType e
let α ← inferType x
let β ← mkLambdaFVars #[x] ety
let u1 ← getLevel α
let u2 ← getLevel ety
return mkAppN (.const ``letFun [u1, u2]) #[α, β, v, f]

/-- Return `a = b`. -/
def mkEq (a b : Expr) : MetaM Expr := do
let aType ← inferType a
Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do
match (← reduceRecMatcher? e) with
| some e => return (← reduce e)
| none => pure ()
if cfg.zeta then
if let some (_, _, v, b) := e.letFun? then
return (← reduce <| b.instantiate1 v)
match (← unfold? e) with
| some e' =>
trace[Meta.Tactic.simp.rewrite] "unfold {mkConst e.getAppFn.constName!}, {e} ==> {e'}"
Expand Down
4 changes: 4 additions & 0 deletions src/Lean/Meta/WHNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,10 @@ where
| .const .. => pure e
| .letE _ _ v b _ => if config.zeta then go <| b.instantiate1 v else return e
| .app f .. =>
if config.zeta then
if let some (_, _, v, b) := e.letFun? then
-- When zeta reducing enabled, always reduce `letFun` no matter the current reducibility level
return (← go <| b.instantiate1 v)
let f := f.getAppFn
let f' ← go f
if config.beta && f'.isLambda then
Expand Down
13 changes: 1 addition & 12 deletions src/Lean/PrettyPrinter/Delaborator/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,6 @@ end Delaborator
open SubExpr (Pos PosMap)
open Delaborator (OptionsPerPos topDownAnalyze)

/-- Custom version of `Lean.Core.betaReduce` to beta reduce expressions for the `pp.beta` option.
We do not want to beta reduce the application in `let_fun` annotations. -/
private partial def betaReduce' (e : Expr) : CoreM Expr :=
Core.transform e (pre := fun e => do
if isLetFun e then
return .done <| e.updateMData! (.app (← betaReduce' e.mdataExpr!.appFn!) (← betaReduce' e.mdataExpr!.appArg!))
else if e.isHeadBetaTarget then
return .visit e.headBeta
else
return .continue)

def delabCore (e : Expr) (optionsPerPos : OptionsPerPos := {}) (delab := Delaborator.delab) : MetaM (Term × PosMap Elab.Info) := do
/- Using `erasePatternAnnotations` here is a bit hackish, but we do it
`Expr.mdata` affects the delaborator. TODO: should we fix that? -/
Expand All @@ -302,7 +291,7 @@ def delabCore (e : Expr) (optionsPerPos : OptionsPerPos := {}) (delab := Delabor
catch _ => pure ()
withOptions (fun _ => opts) do
let e ← if getPPInstantiateMVars opts then instantiateMVars e else pure e
let e ← if getPPBeta opts then betaReduce' e else pure e
let e ← if getPPBeta opts then Core.betaReduce e else pure e
let optionsPerPos ←
if !getPPAll opts && getPPAnalyze opts && optionsPerPos.isEmpty then
topDownAnalyze e
Expand Down
27 changes: 13 additions & 14 deletions src/Lean/PrettyPrinter/Delaborator/Builtins.lean
Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,20 @@ def delabAppMatch : Delab := whenPPOption getPPNotation <| whenPPOption getPPMat
return Syntax.mkApp stx st.moreArgs

/--
Delaborate applications of the form `(fun x => b) v` as `let_fun x := v; b`
Delaborate applications of the form `letFun v (fun x => b)` as `let_fun x := v; b`
-/
def delabLetFun : Delab := do
let stxV ← withAppArg delab
withAppFn do
let Expr.lam n _ b _ ← getExpr | unreachable!
let n ← getUnusedName n b
let stxB ← withBindingBody n delab
if ← getPPOption getPPLetVarTypes <||> getPPOption getPPAnalysisLetVarType then
let stxT ← withBindingDomain delab
`(let_fun $(mkIdent n) : $stxT := $stxV; $stxB)
else
`(let_fun $(mkIdent n) := $stxV; $stxB)
@[builtin_delab app.letFun]
def delabLetFun : Delab := whenPPOption getPPNotation do
guard <| (← getExpr).getAppNumArgs == 4
let Expr.lam n _ b _ := (← getExpr).appArg! | failure
let n ← getUnusedName n b
let stxV ← withAppFn <| withAppArg delab
let stxB ← withAppArg <| withBindingBody n delab
if ← getPPOption getPPLetVarTypes <||> getPPOption getPPAnalysisLetVarType then
let stxT ← SubExpr.withNaryArg 0 delab
`(let_fun $(mkIdent n) : $stxT := $stxV; $stxB)
else
`(let_fun $(mkIdent n) := $stxV; $stxB)

@[builtin_delab mdata]
def delabMData : Delab := do
Expand All @@ -417,8 +418,6 @@ def delabMData : Delab := do
`(.($s)) -- We only include the inaccessible annotation when we are delaborating patterns
else
return s
else if isLetFun (← getExpr) && getPPNotation (← getOptions) then
withMDataExpr <| delabLetFun
else if let some _ := isLHSGoal? (← getExpr) then
withMDataExpr <| withAppFn <| withAppArg <| delab
else
Expand Down
3 changes: 1 addition & 2 deletions tests/lean/1026.lean.expected.out
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
1026.lean:1:4-1:7: warning: declaration uses 'sorry'
1026.lean:10:2-10:12: warning: declaration uses 'sorry'
1026.lean:9:8-9:10: warning: declaration uses 'sorry'
foo._unfold (n : Nat) :
foo n =
if n = 0 then 0
else
let x := n - 1;
let_fun this := foo.proof_3;
let_fun this := foo.proof_4;
foo x
3 changes: 2 additions & 1 deletion tests/lean/heapSort.lean.expected.out
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
heapSort.lean:15:4-15:15: warning: declaration uses 'sorry'
heapSort.lean:15:4-15:15: warning: declaration uses 'sorry'
heapSort.lean:15:4-15:15: warning: declaration uses 'sorry'
heapSort.lean:43:4-43:10: warning: declaration uses 'sorry'
heapSort.lean:58:4-58:13: warning: declaration uses 'sorry'
heapSort.lean:58:4-58:13: warning: declaration uses 'sorry'
heapSort.lean:58:4-58:13: warning: declaration uses 'sorry'
heapSort.lean:102:4-102:13: warning: declaration uses 'sorry'
heapSort.lean:102:4-102:13: warning: declaration uses 'sorry'
Array.heapSort.loop._eq_1.{u_1} {α : Type u_1} (lt : α → α → Bool) (a : BinaryHeap α fun y x => lt x y)
(out : Array α) :
Array.heapSort.loop lt a out =
Expand All @@ -13,4 +15,3 @@ Array.heapSort.loop._eq_1.{u_1} {α : Type u_1} (lt : α → α → Bool) (a : B
| some x =>
let_fun this := (_ : BinaryHeap.size (BinaryHeap.popMax a) < BinaryHeap.size a);
Array.heapSort.loop lt (BinaryHeap.popMax a) (Array.push out x)
heapSort.lean:178:11-178:15: warning: declaration uses 'sorry'
2 changes: 1 addition & 1 deletion tests/lean/letFun.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
f (y + x)

example (a b : Nat) (h1 : a = 0) (h2 : b = 0) : (let_fun x := a + 1; x + x) > b := by
simp (config := { beta := false }) [h1]
simp (config := { zeta := false }) [h1]
trace_state
simp (config := { decide := true }) [h2]

0 comments on commit ca4f83c

Please sign in to comment.