From ca4f83c9e2b145223ba2d8cab929f7947679de9f Mon Sep 17 00:00:00 2001 From: Kyle Miller Date: Mon, 27 Nov 2023 18:37:21 -0800 Subject: [PATCH] feat: use `letFun` function for `let_fun` instead of annotation --- src/Init/Prelude.lean | 4 ++ src/Lean/Compiler/LCNF/ToLCNF.lean | 11 +++--- src/Lean/Elab/Binders.lean | 5 +-- src/Lean/Elab/PreDefinition/Eqns.lean | 8 +++- src/Lean/Expr.lean | 39 ++++++++----------- src/Lean/Meta/AppBuilder.lean | 11 ++++++ src/Lean/Meta/Tactic/Simp/Main.lean | 3 ++ src/Lean/Meta/WHNF.lean | 4 ++ src/Lean/PrettyPrinter/Delaborator/Basic.lean | 13 +------ .../PrettyPrinter/Delaborator/Builtins.lean | 27 +++++++------ tests/lean/1026.lean.expected.out | 3 +- tests/lean/heapSort.lean.expected.out | 3 +- tests/lean/letFun.lean | 2 +- 13 files changed, 70 insertions(+), 63 deletions(-) diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 1b203be23674..dc8f4bc49771 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -66,6 +66,10 @@ example (b : Bool) : Function.const Bool 10 b = 10 := @[inline] def Function.const {α : Sort u} (β : Sort v) (a : α) : β → α := fun _ => a +/-- The encoding of `let_fun x := v; y` is `letFun v (fun x => y)`, +which is an abbreviation for `(fun x => y) v`. -/ +def letFun {α : Sort u} {β : α → Sort v} (v : α) (f : (x : α) → β x) : β v := f v + set_option checkBinderAnnotations false in /-- `inferInstance` synthesizes a value of any target type by typeclass diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 6110dbc16492..00696ffe5b43 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -658,7 +658,9 @@ where visit (f.beta e.getAppArgs) visitApp (e : Expr) : M Arg := do - if let .const declName _ := e.getAppFn then + if let some (n, t, v, b) := e.letFun? then + visitLet (.letE n t v b (nonDep := true)) #[] + else if let .const declName _ := e.getAppFn then if declName == ``Quot.lift then visitQuotLift e else if declName == ``Quot.mk then @@ -725,11 +727,8 @@ where pushElement (.fun funDecl) return .fvar funDecl.fvarId - visitMData (mdata : MData) (e : Expr) : M Arg := do - if let some (.app (.lam n t b ..) v) := letFunAnnotation? (.mdata mdata e) then - visitLet (.letE n t v b (nonDep := true)) #[] - else - visit e + visitMData (_mdata : MData) (e : Expr) : M Arg := do + visit e visitProj (s : Name) (i : Nat) (e : Expr) : M Arg := do match (← visit e) with diff --git a/src/Lean/Elab/Binders.lean b/src/Lean/Elab/Binders.lean index 39218390659b..48d9f08c0493 100644 --- a/src/Lean/Elab/Binders.lean +++ b/src/Lean/Elab/Binders.lean @@ -668,12 +668,11 @@ def elabLetDeclAux (id : Syntax) (binders : Array Syntax) (typeStx : Syntax) (va let body ← instantiateMVars body mkLetFVars #[x] body (usedLetOnly := usedLetOnly) else - let f ← withLocalDecl id.getId (kind := kind) .default type fun x => do + withLocalDecl id.getId (kind := kind) .default type fun x => do addLocalVarInfo id x let body ← elabTermEnsuringType body expectedType? let body ← instantiateMVars body - mkLambdaFVars #[x] body (usedLetOnly := false) - pure <| mkLetFunAnnotation (mkApp f val) + mkLetFun x val body if elabBodyFirst then forallBoundedTelescope type binders.size fun xs type => do -- the original `fvars` from above are gone, so add back info manually diff --git a/src/Lean/Elab/PreDefinition/Eqns.lean b/src/Lean/Elab/PreDefinition/Eqns.lean index 670143db6908..28eebdc3c6e9 100644 --- a/src/Lean/Elab/PreDefinition/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Eqns.lean @@ -24,12 +24,16 @@ structure EqnInfoCore where partial def expand : Expr → Expr | Expr.letE _ _ v b _ => expand (b.instantiate1 v) | Expr.mdata _ b => expand b - | e => e + | e => + if let some (_, _, v, b) := e.letFun? then + expand (b.instantiate1 v) + else + e def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do let target ← mvarId.getType' let some (_, lhs, rhs) := target.eq? | return none - unless rhs.isLet || rhs.isMData do return none + unless rhs.isLet || rhs.isLetFun || rhs.isMData do return none return some (← mvarId.replaceTargetDefEq (← mkEq lhs (expand rhs))) def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 63bdd45d978a..6a8e46833f4a 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -1653,6 +1653,23 @@ def setAppPPExplicitForExposingMVars (e : Expr) : Expr := mkAppN f args |>.setPPExplicit true | _ => e +/-- +Return true if `e` is a `let_fun` expression. +-/ +def isLetFun (e : Expr) : Bool := e.isAppOfArity ``letFun 4 + +/-- Recognizes a `let_fun` expression. For `let_fun n : t := v; b`, returns `some (n, t, v, b)`, +which are the arguments to `Lean.Expr.letE`. + +If in the encoding of `let_fun` the argument to `letFun` is eta reduced, uses `Name.anonymous` for the binder name. -/ +def letFun? (e : Expr) : Option (Name × Expr × Expr × Expr) := + match e with + | .app (.app (.app (.app (.const ``letFun _) α) _β) v) f => + match f with + | .lam n _ b _ => some (n, α, v, b) + | _ => some (.anonymous, α, v, .app f (.bvar 0)) + | _ => none + end Expr /-- @@ -1670,28 +1687,6 @@ def annotation? (kind : Name) (e : Expr) : Option Expr := | .mdata d b => if d.size == 1 && d.getBool kind false then some b else none | _ => none -/-- -Annotate `e` with the `let_fun` annotation. This annotation is used as hint for the delaborator. -If `e` is of the form `(fun x : t => b) v`, then `mkLetFunAnnotation e` is delaborated at -`let_fun x : t := v; b` --/ -def mkLetFunAnnotation (e : Expr) : Expr := - mkAnnotation `let_fun e - -/-- -Return `some e'` if `e = mkLetFunAnnotation e'` --/ -def letFunAnnotation? (e : Expr) : Option Expr := - annotation? `let_fun e - -/-- -Return true if `e = mkLetFunAnnotation e'`, and `e'` is of the form `(fun x : t => b) v` --/ -def isLetFun (e : Expr) : Bool := - match letFunAnnotation? e with - | none => false - | some e => e.isApp && e.appFn!.isLambda - /-- Auxiliary annotation used to mark terms marked with the "inaccessible" annotation `.(t)` and `_` in patterns. diff --git a/src/Lean/Meta/AppBuilder.lean b/src/Lean/Meta/AppBuilder.lean index 39b0f4a408c2..f9a0ce92d45f 100644 --- a/src/Lean/Meta/AppBuilder.lean +++ b/src/Lean/Meta/AppBuilder.lean @@ -24,6 +24,17 @@ def mkExpectedTypeHint (e : Expr) (expectedType : Expr) : MetaM Expr := do let u ← getLevel expectedType return mkApp2 (mkConst ``id [u]) expectedType e +/-- `mkLetFun x v e` creates the encoding for the `let_fun x := v; e` expression. +The expression `x` can either be a free variable or a metavariable, and the function suitably abstracts `x` in `e`. -/ +def mkLetFun (x : Expr) (v : Expr) (e : Expr) : MetaM Expr := do + let f ← mkLambdaFVars #[x] e + let ety ← inferType e + let α ← inferType x + let β ← mkLambdaFVars #[x] ety + let u1 ← getLevel α + let u2 ← getLevel ety + return mkAppN (.const ``letFun [u1, u2]) #[α, β, v, f] + /-- Return `a = b`. -/ def mkEq (a b : Expr) : MetaM Expr := do let aType ← inferType a diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 8fa0e427ead3..322bf75e7474 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -250,6 +250,9 @@ private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do match (← reduceRecMatcher? e) with | some e => return (← reduce e) | none => pure () + if cfg.zeta then + if let some (_, _, v, b) := e.letFun? then + return (← reduce <| b.instantiate1 v) match (← unfold? e) with | some e' => trace[Meta.Tactic.simp.rewrite] "unfold {mkConst e.getAppFn.constName!}, {e} ==> {e'}" diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 7f4ddfebdac6..f732a8f11c44 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -560,6 +560,10 @@ where | .const .. => pure e | .letE _ _ v b _ => if config.zeta then go <| b.instantiate1 v else return e | .app f .. => + if config.zeta then + if let some (_, _, v, b) := e.letFun? then + -- When zeta reducing enabled, always reduce `letFun` no matter the current reducibility level + return (← go <| b.instantiate1 v) let f := f.getAppFn let f' ← go f if config.beta && f'.isLambda then diff --git a/src/Lean/PrettyPrinter/Delaborator/Basic.lean b/src/Lean/PrettyPrinter/Delaborator/Basic.lean index 3f2530b6943b..4ee0d04c0848 100644 --- a/src/Lean/PrettyPrinter/Delaborator/Basic.lean +++ b/src/Lean/PrettyPrinter/Delaborator/Basic.lean @@ -277,17 +277,6 @@ end Delaborator open SubExpr (Pos PosMap) open Delaborator (OptionsPerPos topDownAnalyze) -/-- Custom version of `Lean.Core.betaReduce` to beta reduce expressions for the `pp.beta` option. -We do not want to beta reduce the application in `let_fun` annotations. -/ -private partial def betaReduce' (e : Expr) : CoreM Expr := - Core.transform e (pre := fun e => do - if isLetFun e then - return .done <| e.updateMData! (.app (← betaReduce' e.mdataExpr!.appFn!) (← betaReduce' e.mdataExpr!.appArg!)) - else if e.isHeadBetaTarget then - return .visit e.headBeta - else - return .continue) - def delabCore (e : Expr) (optionsPerPos : OptionsPerPos := {}) (delab := Delaborator.delab) : MetaM (Term × PosMap Elab.Info) := do /- Using `erasePatternAnnotations` here is a bit hackish, but we do it `Expr.mdata` affects the delaborator. TODO: should we fix that? -/ @@ -302,7 +291,7 @@ def delabCore (e : Expr) (optionsPerPos : OptionsPerPos := {}) (delab := Delabor catch _ => pure () withOptions (fun _ => opts) do let e ← if getPPInstantiateMVars opts then instantiateMVars e else pure e - let e ← if getPPBeta opts then betaReduce' e else pure e + let e ← if getPPBeta opts then Core.betaReduce e else pure e let optionsPerPos ← if !getPPAll opts && getPPAnalyze opts && optionsPerPos.isEmpty then topDownAnalyze e diff --git a/src/Lean/PrettyPrinter/Delaborator/Builtins.lean b/src/Lean/PrettyPrinter/Delaborator/Builtins.lean index a75323940882..e97183ff5398 100644 --- a/src/Lean/PrettyPrinter/Delaborator/Builtins.lean +++ b/src/Lean/PrettyPrinter/Delaborator/Builtins.lean @@ -395,19 +395,20 @@ def delabAppMatch : Delab := whenPPOption getPPNotation <| whenPPOption getPPMat return Syntax.mkApp stx st.moreArgs /-- - Delaborate applications of the form `(fun x => b) v` as `let_fun x := v; b` + Delaborate applications of the form `letFun v (fun x => b)` as `let_fun x := v; b` -/ -def delabLetFun : Delab := do - let stxV ← withAppArg delab - withAppFn do - let Expr.lam n _ b _ ← getExpr | unreachable! - let n ← getUnusedName n b - let stxB ← withBindingBody n delab - if ← getPPOption getPPLetVarTypes <||> getPPOption getPPAnalysisLetVarType then - let stxT ← withBindingDomain delab - `(let_fun $(mkIdent n) : $stxT := $stxV; $stxB) - else - `(let_fun $(mkIdent n) := $stxV; $stxB) +@[builtin_delab app.letFun] +def delabLetFun : Delab := whenPPOption getPPNotation do + guard <| (← getExpr).getAppNumArgs == 4 + let Expr.lam n _ b _ := (← getExpr).appArg! | failure + let n ← getUnusedName n b + let stxV ← withAppFn <| withAppArg delab + let stxB ← withAppArg <| withBindingBody n delab + if ← getPPOption getPPLetVarTypes <||> getPPOption getPPAnalysisLetVarType then + let stxT ← SubExpr.withNaryArg 0 delab + `(let_fun $(mkIdent n) : $stxT := $stxV; $stxB) + else + `(let_fun $(mkIdent n) := $stxV; $stxB) @[builtin_delab mdata] def delabMData : Delab := do @@ -417,8 +418,6 @@ def delabMData : Delab := do `(.($s)) -- We only include the inaccessible annotation when we are delaborating patterns else return s - else if isLetFun (← getExpr) && getPPNotation (← getOptions) then - withMDataExpr <| delabLetFun else if let some _ := isLHSGoal? (← getExpr) then withMDataExpr <| withAppFn <| withAppArg <| delab else diff --git a/tests/lean/1026.lean.expected.out b/tests/lean/1026.lean.expected.out index 4d89e3b8dc60..64966ecc9ad9 100644 --- a/tests/lean/1026.lean.expected.out +++ b/tests/lean/1026.lean.expected.out @@ -1,10 +1,9 @@ 1026.lean:1:4-1:7: warning: declaration uses 'sorry' -1026.lean:10:2-10:12: warning: declaration uses 'sorry' 1026.lean:9:8-9:10: warning: declaration uses 'sorry' foo._unfold (n : Nat) : foo n = if n = 0 then 0 else let x := n - 1; - let_fun this := foo.proof_3; + let_fun this := foo.proof_4; foo x diff --git a/tests/lean/heapSort.lean.expected.out b/tests/lean/heapSort.lean.expected.out index ff100c7fc616..36c421cf6bc7 100644 --- a/tests/lean/heapSort.lean.expected.out +++ b/tests/lean/heapSort.lean.expected.out @@ -1,10 +1,12 @@ heapSort.lean:15:4-15:15: warning: declaration uses 'sorry' heapSort.lean:15:4-15:15: warning: declaration uses 'sorry' +heapSort.lean:15:4-15:15: warning: declaration uses 'sorry' heapSort.lean:43:4-43:10: warning: declaration uses 'sorry' heapSort.lean:58:4-58:13: warning: declaration uses 'sorry' heapSort.lean:58:4-58:13: warning: declaration uses 'sorry' heapSort.lean:58:4-58:13: warning: declaration uses 'sorry' heapSort.lean:102:4-102:13: warning: declaration uses 'sorry' +heapSort.lean:102:4-102:13: warning: declaration uses 'sorry' Array.heapSort.loop._eq_1.{u_1} {α : Type u_1} (lt : α → α → Bool) (a : BinaryHeap α fun y x => lt x y) (out : Array α) : Array.heapSort.loop lt a out = @@ -13,4 +15,3 @@ Array.heapSort.loop._eq_1.{u_1} {α : Type u_1} (lt : α → α → Bool) (a : B | some x => let_fun this := (_ : BinaryHeap.size (BinaryHeap.popMax a) < BinaryHeap.size a); Array.heapSort.loop lt (BinaryHeap.popMax a) (Array.push out x) -heapSort.lean:178:11-178:15: warning: declaration uses 'sorry' diff --git a/tests/lean/letFun.lean b/tests/lean/letFun.lean index 796f572f3121..7922b3335547 100644 --- a/tests/lean/letFun.lean +++ b/tests/lean/letFun.lean @@ -5,6 +5,6 @@ f (y + x) example (a b : Nat) (h1 : a = 0) (h2 : b = 0) : (let_fun x := a + 1; x + x) > b := by - simp (config := { beta := false }) [h1] + simp (config := { zeta := false }) [h1] trace_state simp (config := { decide := true }) [h2]