From 31a1057fe58a0ab4f14ac42cb9abfb2d862e2bbf Mon Sep 17 00:00:00 2001 From: Son HO Date: Wed, 14 Aug 2024 13:28:59 +0200 Subject: [PATCH] Add an AVL tree in the examples (#283) * tests: add AVL example Signed-off-by: Ryan Lahfa * tests: add Lean formalization for AVL Signed-off-by: Ryan Lahfa * Fix minor extraction issues * Start fixing the AVL tree example * Commit some missing files * Make a minor modification * Update the code of the AVL tree * Fix the code of the AVL tree * Regenerate the code of the AVL tree * Make progress on updating the proofs of the AVL trees * Make progress on updating the proofs of the AVL tree * Make minor modifications to the AVLs * Move and update the code of the AVL tree * Regenerate the Lean model for the AVL tree * Generate simp lemmas for the custom field projectors * Regenerate the types for the AVL test * Make progress on the proof of the AVL tree * Make progress on the proofs * Make progress on the proofs of the AVL tree * Make progress on the proofs of the AVL * Cleanup a bit * Cleanup a bit * Do more cleanup * Make progress on the AVL tree * Make good progress on the avl * Make minor modifications * Fix pspec for mutually recursive functions * Regenerate the test files * Improve pspec and progress and cleanup a bit * Add a comment * Fix an issue with the CI and regenerate the tests * Cleanup * Do more cleanup * Make progress generate warnings instead of errors if there are too many ids * Fix a CI issue * Fix a bug in destructure_abs * Cleanup a bit * Fix a minor issue in the flake.nix --------- Signed-off-by: Ryan Lahfa Co-authored-by: Ryan Lahfa --- backends/lean/Base.lean | 1 + backends/lean/Base/Arith/Int.lean | 19 +- backends/lean/Base/Arith/Scalar.lean | 4 + backends/lean/Base/Diverge/Elab.lean | 6 +- backends/lean/Base/Primitives/Base.lean | 1 + backends/lean/Base/Primitives/Core.lean | 21 +- backends/lean/Base/Progress/Base.lean | 44 +- backends/lean/Base/Progress/Progress.lean | 92 +- backends/lean/Base/SimpLemmas.lean | 8 + backends/lean/Base/Termination.lean | 3 + backends/lean/Base/Utils.lean | 4 + compiler/ExtractBuiltin.ml | 11 +- compiler/ExtractTypes.ml | 143 ++- compiler/InterpreterBorrows.ml | 7 + compiler/InterpreterLoopsJoinCtxs.ml | 4 +- flake.nix | 2 + tests/lean/Avl.lean | 1 + tests/lean/Avl/Funs.lean | 238 ++++ tests/lean/Avl/OrderSpec.lean | 156 +++ tests/lean/Avl/Properties.lean | 1030 +++++++++++++++++ tests/lean/Avl/ScalarOrder.lean | 59 + tests/lean/Avl/Types.lean | 70 ++ tests/lean/Betree/Types.lean | 28 +- tests/lean/Hashmap/Properties.lean | 24 +- .../Issue194RecursiveStructProjector.lean | 21 +- tests/lean/MiniTree.lean | 6 +- tests/lean/lakefile.lean | 7 +- tests/src/avl/Cargo.lock | 7 + tests/src/avl/Cargo.toml | 8 + tests/src/avl/aeneas-test-options | 2 + tests/src/avl/rust-toolchain | 3 + tests/src/avl/src/avl.rs | 483 ++++++++ tests/test_runner/run_test.ml | 47 +- 33 files changed, 2474 insertions(+), 86 deletions(-) create mode 100644 backends/lean/Base/SimpLemmas.lean create mode 100644 tests/lean/Avl.lean create mode 100644 tests/lean/Avl/Funs.lean create mode 100644 tests/lean/Avl/OrderSpec.lean create mode 100644 tests/lean/Avl/Properties.lean create mode 100644 tests/lean/Avl/ScalarOrder.lean create mode 100644 tests/lean/Avl/Types.lean create mode 100644 tests/src/avl/Cargo.lock create mode 100644 tests/src/avl/Cargo.toml create mode 100644 tests/src/avl/aeneas-test-options create mode 100644 tests/src/avl/rust-toolchain create mode 100644 tests/src/avl/src/avl.rs diff --git a/backends/lean/Base.lean b/backends/lean/Base.lean index 53baae1e5..e2808f270 100644 --- a/backends/lean/Base.lean +++ b/backends/lean/Base.lean @@ -3,5 +3,6 @@ import Base.Diverge import Base.IList import Base.Primitives import Base.Progress +import Base.SimpLemmas import Base.Utils import Base.Termination diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean index be64089d6..62b696693 100644 --- a/backends/lean/Base/Arith/Int.lean +++ b/backends/lean/Base/Arith/Int.lean @@ -95,7 +95,7 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U -- might have proven the goal, hence the `Tactic.allGoals` let dsimp := Tactic.allGoals do tryTac ( - -- We set `simpOnly` at false on purpose + -- We set `simpOnly` at false on purpose. dsimpAt false {} intTacSimpRocs -- Declarations to unfold [] @@ -107,6 +107,17 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U Tactic.allGoals (Utils.tryTac (Utils.normCastAtAll)) -- norm_cast does weird things with negative numbers so we reapply simp dsimp + -- Int.subNatNat is very annoying - TODO: there is probably something more general thing to do + Utils.tryTac ( + Utils.simpAt true {} + -- Simprocs + [] + -- Unfoldings + [] + -- Simp lemmas + [``Int.subNatNat_eq_coe] + -- Hypotheses + [] .wildcard) -- We also need this, in case the goal is: ¬ False Tactic.allGoals do tryTac ( Utils.simpAt true {} @@ -178,4 +189,10 @@ example (a : Prop) (x : Int) (h0: (0 : Nat) < x) (h1: x < 0) : a := by example (x : Int) (h : x ≤ -3) : x ≤ -2 := by int_tac +example (x y : Int) (h : x + y = 3) : + let z := x + y + z = 3 := by + intro z + omega + end Arith diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index ce2380486..0ee8a2566 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -89,4 +89,8 @@ example (x : I32) : -100000000000 < x.val := by example : (Usize.ofInt 2).val ≠ 0 := by scalar_tac +example (x y : Nat) (z : Int) (h : Int.subNatNat x y + z = 0) : (x : Int) - (y : Int) + z = 0 := by + scalar_tac_preprocess + omega + end Arith diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 609550512..86de54b58 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -661,12 +661,10 @@ mutual -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveExprIsValid: {e}" - -- Normalize to eliminate the lambdas - TODO: this is slightly dangerous. + -- Normalize to eliminate the let-bindings let e ← do if e.isLet ∧ normalize_let_bindings then do - let updt_config (config : Lean.Meta.Config) := - { config with transparency := .reducible } - let e ← withConfig updt_config (whnf e) + let e ← normalizeLetBindings e trace[Diverge.def.valid] "e (after normalization): {e}" pure e else pure e diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index 63fbd8c04..2a8d9c7ae 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -76,6 +76,7 @@ def eval_global {α: Type u} (x: Result α) (_: ok? x := by prove_eval_global) : | fail _ | div => by contradiction | ok x => x +@[simp] def Result.ofOption {a : Type u} (x : Option a) (e : Error) : Result a := match x with | some x => ok x diff --git a/backends/lean/Base/Primitives/Core.lean b/backends/lean/Base/Primitives/Core.lean index aa4a7f285..a49f69f15 100644 --- a/backends/lean/Base/Primitives/Core.lean +++ b/backends/lean/Base/Primitives/Core.lean @@ -51,17 +51,6 @@ def clone.CloneBool : clone.Clone Bool := { clone := fun b => ok (clone.impls.CloneBool.clone b) } -namespace option -- core.option - -/- [core::option::{core::option::Option}::unwrap] -/ -def Option.unwrap (T : Type) (x : Option T) : Result T := - Result.ofOption x Error.panic - -end option -- core.option - -/- [core::option::Option::take] -/ -@[simp] def Option.take (T: Type) (self: Option T): Option T × Option T := (self, .none) - /- [core::mem::replace] This acts like a swap effectively in a functional pure world. @@ -73,3 +62,13 @@ end option -- core.option @[simp] def mem.swap (T: Type) (a b: T): T × T := (b, a) end core + +/- [core::option::{core::option::Option}::unwrap] -/ +@[simp] def core.option.Option.unwrap (T : Type) (x : Option T) : Result T := + Result.ofOption x Error.panic + +/- [core::option::Option::take] -/ +@[simp] def core.option.Option.take (T: Type) (self: Option T): Option T × Option T := (self, .none) + +/- [core::option::Option::is_none] -/ +@[simp] def core.option.Option.is_none (T: Type) (self: Option T): Bool := self.isNone diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 0e46737f1..7cfacdca5 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -81,7 +81,7 @@ section Methods trace[Progress] "After splitting the conjunction:\n- eq: {th}\n- post: {post}" -- Destruct the equality let (mExpr, ret) ← destEq th.consumeMData - trace[Progress] "After splitting the equality:\n- lhs: {th}\n- rhs: {ret}" + trace[Progress] "After splitting the equality:\n- lhs: {mExpr}\n- rhs: {ret}" -- Recursively destruct the monadic application to dive into the binds, -- if necessary (this is for when we use `withPSpec` inside of the `progress` tactic), -- and destruct the application to get the function name @@ -154,18 +154,36 @@ initialize pspecAttr : PSpecAttr ← do throwError "invalid attribute 'pspec', must be global" -- Lookup the theorem let env ← getEnv - let thDecl := env.constants.find! thName - let fKey ← MetaM.run' (do - let fExpr ← getPSpecFunArgsExpr false thDecl.type - trace[Progress] "Registering spec theorem for {fExpr}" - -- Convert the function expression to a discrimination tree key - -- We use the default configuration - let config : WhnfCoreConfig := {} - DiscrTree.mkPath fExpr config) - let env := ext.addEntry env (fKey, thName) - setEnv env - trace[Progress] "Saved the environment" - pure () + -- If we apply the attribute to a definition in a group of mutually recursive definitions + -- (say, to `foo` in the group [`foo`, `bar`]), the attribute gets applied to `foo` but also to + -- the recursive definition which encodes `foo` and `bar` (Lean encodes mutually recursive + -- definitions in one recursive definition, e.g., `foo._mutual`, before deriving the individual + -- definitions, e.g., `foo` and `bar`, from this one). This definition should be named `foo._mutual` + -- or `bar._mutual`, and we want to ignore it. + -- TODO: this is a hack + if let .str _ "_mutual" := thName then + -- Ignore: this is the fixed point of a mutually recursive definition - + -- the attribute will also be applied to the definitions revealed to the user: + -- we want to apply the attribute for those + trace[Progress] "Ignoring a mutually recursive definition: {thName}" + else + trace[Progress] "Registering spec theorem for {thName}" + let thDecl := env.constants.find! thName + let fKey ← MetaM.run' (do + trace[Progress] "Theorem: {thDecl.type}" + -- Normalize to eliminate the let-bindings + let ty ← normalizeLetBindings thDecl.type + trace[Progress] "Theorem after normalization (to eliminate the let bindings): {ty}" + let fExpr ← getPSpecFunArgsExpr false ty + trace[Progress] "Registering spec theorem for expr: {fExpr}" + -- Convert the function expression to a discrimination tree key + -- We use the default configuration + let config : WhnfCoreConfig := {} + DiscrTree.mkPath fExpr config) + let env := ext.addEntry env (fKey, thName) + setEnv env + trace[Progress] "Saved the environment" + pure () } registerBuiltinAttribute attrImpl pure { attr := attrImpl, ext := ext } diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index cea46da8e..12609d7f3 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -73,6 +73,9 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) inferType th | .Local asmDecl => pure asmDecl.type trace[Progress] "Looked up theorem/assumption type: {thTy}" + -- Normalize to inline the let-bindings + let thTy ← normalizeLetBindings thTy + trace[Progress] "After normalizing the let-bindings: {thTy}" -- TODO: the tactic fails if we uncomment withNewMCtxDepth -- withNewMCtxDepth do let (mvars, binders, thExBody) ← forallMetaTelescope thTy @@ -104,6 +107,13 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) | .Local decl => mkAppOptM' (mkFVar decl.fvarId) (mvars.map some) let asmName ← do match keep with | none => mkFreshAnonPropUserName | some n => do pure n let thTy ← inferType th + trace[Progress] "thTy (after application): {thTy}" + -- Normalize the let-bindings (note that we already inlined the let bindings once above when analizing + -- the theorem, now we do it again on the instantiated theorem - there is probably a smarter way to do, + -- but it doesn't really matter). + -- TODO: actually we might want to let the user insert them in the context + let thTy ← normalizeLetBindings thTy + trace[Progress] "thTy (after normalizing let-bindings): {thTy}" let thAsm ← Utils.addDeclTac asmName th thTy (asLet := false) withMainContext do -- The context changed - TODO: remove once addDeclTac is updated let ngoal ← getMainGoal @@ -138,7 +148,11 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) let _ ← tryTac (simpAt true {} [] [] - [``Primitives.bind_tc_ok, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] + [``Primitives.bind_tc_ok, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div, + -- This last one is quite useful. In particular, it is necessary to rewrite the + -- conjunctions for Lean to automatically instantiate the existential quantifiers + -- (I don't know why). + ``and_assoc] [hEq.fvarId!] (.targets #[] true)) -- It may happen that at this point the goal is already solved (though this is rare) -- TODO: not sure this is the best way of checking it @@ -161,8 +175,8 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) | none => -- Sanity check if ¬ ids.isEmpty then - return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split") - else return .Ok + logWarning m!"Too many ids provided ({ids}): there is no postcondition to split" + return .Ok | some hPost => do let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do match ids0 with @@ -183,7 +197,9 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}" if ← isConj (← inferType hPost) then splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids) - else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition") + else + logWarning m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition" + pure .Ok let curPostId := (← hPost.fvarId!.getDecl).userName splitPostWithIds curPostId hPost ids match res with @@ -459,7 +475,6 @@ namespace Test progress simp [*] - example {α : Type} (v: Vec α) (i: Usize) (x : α) (hbounds : i.val < v.length) : ∃ nv, v.update_usize i x = ok nv ∧ @@ -505,6 +520,73 @@ namespace Test progress keep _ as ⟨ z, h1 .. ⟩ simp [*, h1] + -- Testing with mutually recursive definitions + mutual + inductive Tree + | mk : Trees → Tree + + inductive Trees + | nil + | cons : Tree → Trees → Trees + end + + mutual + def Tree.size (t : Tree) : Result Int := + match t with + | .mk trees => trees.size + + def Trees.size (t : Trees) : Result Int := + match t with + | .nil => ok 0 + | .cons t t' => do + let s ← t.size + let s' ← t'.size + ok (s + s') + end + + mutual + @[pspec] + theorem Tree.size_spec (t : Tree) : + ∃ i, t.size = ok i ∧ i ≥ 0 := by + cases t + simp [Tree.size] + progress + simp + omega + + @[pspec] + theorem Trees.size_spec (t : Trees) : + ∃ i, t.size = ok i ∧ i ≥ 0 := by + cases t <;> simp [Trees.size] + progress + progress + simp + omega + end + + -- Testing progress on theorems containing local let-bindings + def add (x y : U32) : Result U32 := x + y + + @[pspec] -- TODO: give the possibility of using pspec as a local attribute + theorem add_spec x y (h : x.val + y.val ≤ U32.max) : + let tot := x.val + y.val + ∃ z, add x y = ok z ∧ z.val = tot := by + rw [add] + intro tot + progress + simp [*] + + def add1 (x y : U32) : Result U32 := do + let z ← add x y + add z z + + example (x y : U32) (h : 2 * x.val + 2 * y.val ≤ U32.max) : + ∃ z, add1 x y = ok z := by + rw [add1] + progress as ⟨ z1, h ⟩ + progress as ⟨ z2, h ⟩ + simp + end Test end Progress diff --git a/backends/lean/Base/SimpLemmas.lean b/backends/lean/Base/SimpLemmas.lean new file mode 100644 index 000000000..794f5c8a6 --- /dev/null +++ b/backends/lean/Base/SimpLemmas.lean @@ -0,0 +1,8 @@ +import Lean + +-- This simplification lemma is very useful especially for the kind of theorems we prove, +-- which are of the shape: `∃ x y ..., ... ∧ ... ∧ ...`. +-- Using this simp lemma often triggers simplifications like the instantiation of the +-- existential quantifiers when there is an equality somewhere: +-- `∃ x, x = y ∧ P x` gets rewritten to `P y` +attribute [simp] and_assoc diff --git a/backends/lean/Base/Termination.lean b/backends/lean/Base/Termination.lean index de8e678fb..294e8e325 100644 --- a/backends/lean/Base/Termination.lean +++ b/backends/lean/Base/Termination.lean @@ -91,4 +91,7 @@ macro_rules -- Finish simp_all <;> scalar_tac) +-- We always activate this simplification lemma because it is useful for the proofs of termination +attribute [simp] Prod.lex_iff + end Utils diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 9ef628d23..fd628fae3 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -807,4 +807,8 @@ def evalAesopSaturate (options : Aesop.Options') (ruleSets : Array Name) : Tacti |> Aesop.ElabM.runForwardElab (← getMainGoal) tryLiftMetaTactic1 (Aesop.saturate rs · |>.run { options }) "Aesop.saturate failed" +/-- Normalize the let-bindings by inlining them -/ +def normalizeLetBindings (e : Expr) : MetaM Expr := + zetaReduce e + end Utils diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index c8d9af2e4..0232cc49d 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -497,17 +497,14 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = (* Lean-only definitions *) @ mk_lean_only [ - (* `backend_choice` first parameter is for non-Lean backends - By construction, we cannot write down that parameter in the output - in this list - *) mk_fun "core::mem::swap" ~can_fail:false (); mk_fun "core::option::{core::option::Option<@T>}::take" - ~extract_name:(Some (backend_choice "" "Option::take")) + ~extract_name:(Some (backend_choice "" "core::option::Option::take")) ~can_fail:false (); mk_fun "core::option::{core::option::Option<@T>}::is_none" - ~extract_name:(Some (backend_choice "" "Option::isNone")) - ~filter:(Some [ false ]) ~can_fail:false (); + ~extract_name: + (Some (backend_choice "" "core::option::Option::is_none")) + ~can_fail:false (); ] let builtin_funs : unit -> (pattern * bool list option * builtin_fun_info) list diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 12b0debaf..2b3605580 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -1740,7 +1740,7 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) (* Annotate the projectors with both simp and reducible. The first one allows to automatically unfold when calling simp in proofs. The second ensures that projectors will interact well with the unifier *) - F.pp_print_string fmt "@[simp, reducible]"; + F.pp_print_string fmt "@[reducible]"; F.pp_print_break fmt 0 0; (* Close box for the attributes *) F.pp_close_box fmt ()); @@ -1890,6 +1890,142 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) FieldId.iteri extract_field_proj_and_notation fields +(** Auxiliary function. + + Generate field projectors simp lemmas for Lean. + + See {!extract_type_decl_record_field_projectors}. + *) +let extract_type_decl_record_field_projectors_simp_lemmas (ctx : extraction_ctx) + (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit = + let span = decl.item_meta.span in + sanity_check __FILE__ __LINE__ (backend () = Coq || backend () = Lean) span; + match decl.kind with + | Opaque | Enum _ -> () + | Struct fields -> + (* Records are extracted as inductives only if they are recursive *) + let is_rec = decl_is_from_rec_group kind in + if is_rec then + (* Add the type params *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params span decl.item_meta.name decl.llbc_generics + decl.generics ctx + in + (* Name of the ADT *) + let def_name = ctx_get_local_type span decl.def_id ctx in + (* Name of the ADT constructor. As we are in the struct case, we only have + one constructor *) + let cons_name = ctx_get_struct span (TAdtId decl.def_id) ctx in + + let extract_field_proj_simp_lemma (field_id : FieldId.id) (_ : field) : + unit = + F.pp_print_space fmt (); + (* Outer box for the projector definition *) + F.pp_open_hvbox fmt 0; + (* Inner box for the projector definition *) + F.pp_open_hvbox fmt ctx.indent_incr; + + (* For Lean: add some attributes *) + if backend () = Lean then ( + (* Box for the attributes *) + F.pp_open_vbox fmt 0; + (* Annotate the projectors with both simp and reducible. + The first one allows to automatically unfold when calling simp in proofs. + The second ensures that projectors will interact well with the unifier *) + F.pp_print_string fmt "@[simp]"; + F.pp_print_break fmt 0 0; + (* Close box for the attributes *) + F.pp_close_box fmt ()); + + (* Box for the [theorem ... : ... = ... :=] *) + F.pp_open_hovbox fmt ctx.indent_incr; + (match backend () with + | Lean -> F.pp_print_string fmt "theorem" + | _ -> internal_error __FILE__ __LINE__ span); + F.pp_print_space fmt (); + + (* Print the theorem name. *) + let field_name = + ctx_get_field span (TAdtId decl.def_id) field_id ctx + in + (* TODO: check for name collisions *) + F.pp_print_string fmt (def_name ^ "." ^ field_name ^ "._simpLemma_"); + + (* Print the generics *) + let as_implicits = true in + extract_generic_params span ctx fmt TypeDeclId.Set.empty ~as_implicits + decl.generics type_params cg_params trait_clauses; + + (* Print the input parameters (the fields) *) + let print_field (ctx : extraction_ctx) (field_id : FieldId.id) + (f : field) : extraction_ctx * string = + let id = VarId.of_int (FieldId.to_int field_id) in + let field_name = + ctx_get_field span (TAdtId decl.def_id) field_id ctx + in + let ctx, vname = ctx_add_var span field_name id ctx in + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + F.pp_print_string fmt vname; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty span ctx fmt TypeDeclId.Set.empty false f.field_ty; + F.pp_print_string fmt ")"; + (ctx, field_name) + in + let _, field_names = + List.fold_left_map + (fun ctx (id, f) -> print_field ctx id f) + ctx + (FieldId.mapi (fun i f -> (i, f)) fields) + in + + (* The theorem content *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + (* (mk ... x ...).f = x *) + (* Open a box for the theorem content *) + F.pp_open_hvbox fmt ctx.indent_incr; + F.pp_print_string fmt "("; + F.pp_print_string fmt cons_name; + List.iter + (fun f -> + F.pp_print_space fmt (); + F.pp_print_string fmt f) + field_names; + F.pp_print_string fmt (")." ^ field_name); + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + F.pp_print_space fmt (); + F.pp_print_string fmt (FieldId.nth field_names field_id); + (* Close the box for the theorem content *) + F.pp_close_box fmt (); + + (* The proof *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":="; + F.pp_print_space fmt (); + F.pp_print_string fmt "by rfl"; + + (* Close the box for the [theorem ... :=] *) + F.pp_close_box fmt (); + + (* Close the inner box projector *) + F.pp_close_box fmt (); + (* If Coq: end the definition with a "." *) + if backend () = Coq then ( + F.pp_print_cut fmt (); + F.pp_print_string fmt "."); + (* Close the outer box for projector definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + in + + FieldId.iteri extract_field_proj_simp_lemma fields + (** Extract extra information for a type (e.g., [Arguments] instructions in Coq). Note that all the names used for extraction should already have been @@ -1907,7 +2043,10 @@ let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) then ( if backend () = Coq then extract_type_decl_coq_arguments ctx fmt kind decl; - extract_type_decl_record_field_projectors ctx fmt kind decl) + extract_type_decl_record_field_projectors ctx fmt kind decl; + if backend () = Lean then + extract_type_decl_record_field_projectors_simp_lemmas ctx fmt kind + decl) (** Extract the state type declaration. *) let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index f67be365b..52b03be61 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -1803,6 +1803,13 @@ let destructure_abs (span : Meta.span) (abs_kind : abs_kind) (can_end : bool) let value = ALoan (ASharedLoan (PNone, bids, sv, mk_aignored span ty)) in + (* We need to update the type of the value: abstract shared loans + have the type `& ...` - TODO: this is annoying and not very clean... *) + let ty = + (* Take the first region of the abstraction - this doesn't really matter *) + let r = RegionId.Set.min_elt abs0.regions in + TRef (RFVar r, ty, RShared) + in { value; ty } in let avl = List.append [ av ] avl in diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index 4f89b22f2..1ef857aee 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -889,7 +889,9 @@ let destructure_new_abs (span : Meta.span) (loop_id : LoopId.id) else abs) ctx.env in - { ctx with env } + let ctx = { ctx with env } in + Invariants.check_invariants span ctx; + ctx (** Refresh the ids of the fresh abstractions. diff --git a/flake.nix b/flake.nix index 7249be60b..dcd0a7fbd 100644 --- a/flake.nix +++ b/flake.nix @@ -131,6 +131,8 @@ IN_CI=1 make test-all -j $NIX_BUILD_CORES # Clean generated llbc files so we don't compare them. rm -r tests/llbc + # Remove the `target` cache folder generated by cargo + rm -rf tests/src/*/target # Check that there are no differences between the generated tests # and the original tests diff --git a/tests/lean/Avl.lean b/tests/lean/Avl.lean new file mode 100644 index 000000000..4ab2cd5c9 --- /dev/null +++ b/tests/lean/Avl.lean @@ -0,0 +1 @@ +import Avl.Properties diff --git a/tests/lean/Avl/Funs.lean b/tests/lean/Avl/Funs.lean new file mode 100644 index 000000000..f0ae613d2 --- /dev/null +++ b/tests/lean/Avl/Funs.lean @@ -0,0 +1,238 @@ +-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS +-- [avl]: function definitions +import Base +import Avl.Types +open Primitives +set_option linter.dupNamespace false +set_option linter.hashCommand false +set_option linter.unusedVariables false + +namespace avl + +/- [avl::{avl::Ord for i32}::cmp]: + Source: 'src/avl.rs', lines 8:4-8:43 -/ +def OrdI32.cmp (self : I32) (other : I32) : Result Ordering := + if self < other + then Result.ok Ordering.Less + else + if self = other + then Result.ok Ordering.Equal + else Result.ok Ordering.Greater + +/- Trait implementation: [avl::{avl::Ord for i32}] + Source: 'src/avl.rs', lines 7:0-7:16 -/ +def OrdI32 : Ord I32 := { + cmp := OrdI32.cmp +} + +/- [avl::{avl::Node}#1::rotate_left]: + Source: 'src/avl.rs', lines 41:4-41:65 -/ +def Node.rotate_left + (T : Type) (root : Node T) (z : Node T) : Result (Node T) := + let (b, o) := core.mem.replace (Option (Node T)) z.left none + let (x, root1) := + core.mem.replace (Node T) (Node.mk root.value root.left b + root.balance_factor) (Node.mk z.value o z.right z.balance_factor) + if root1.balance_factor = 0#i8 + then + Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 1#i8)) + root1.right (-1)#i8) + else + Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) + root1.right 0#i8) + +/- [avl::{avl::Node}#1::rotate_right]: + Source: 'src/avl.rs', lines 92:4-92:66 -/ +def Node.rotate_right + (T : Type) (root : Node T) (z : Node T) : Result (Node T) := + let (b, o) := core.mem.replace (Option (Node T)) z.right none + let (x, root1) := + core.mem.replace (Node T) (Node.mk root.value b root.right + root.balance_factor) (Node.mk z.value z.left o z.balance_factor) + if root1.balance_factor = 0#i8 + then + Result.ok (Node.mk root1.value root1.left (some (Node.mk x.value + x.left x.right (-1)#i8)) 1#i8) + else + Result.ok (Node.mk root1.value root1.left (some (Node.mk x.value + x.left x.right 0#i8)) 0#i8) + +/- [avl::{avl::Node}#1::rotate_left_right]: + Source: 'src/avl.rs', lines 138:4-138:72 -/ +def Node.rotate_left_right + (T : Type) (root : Node T) (z : Node T) : Result (Node T) := + do + let (o, _) := core.mem.replace (Option (Node T)) z.right none + let y ← core.option.Option.unwrap (Node T) o + let (a, o1) := core.mem.replace (Option (Node T)) y.left none + let (b, o2) := core.mem.replace (Option (Node T)) y.right none + let (x, root1) := + core.mem.replace (Node T) (Node.mk root.value b root.right + root.balance_factor) (Node.mk y.value o1 o2 y.balance_factor) + if root1.balance_factor = 0#i8 + then + Result.ok (Node.mk root1.value (some (Node.mk z.value z.left a 0#i8)) (some + (Node.mk x.value x.left x.right 0#i8)) 0#i8) + else + if root1.balance_factor < 0#i8 + then + Result.ok (Node.mk root1.value (some (Node.mk z.value z.left a 0#i8)) + (some (Node.mk x.value x.left x.right 1#i8)) 0#i8) + else + Result.ok (Node.mk root1.value (some (Node.mk z.value z.left a (-1)#i8)) + (some (Node.mk x.value x.left x.right 0#i8)) 0#i8) + +/- [avl::{avl::Node}#1::rotate_right_left]: + Source: 'src/avl.rs', lines 188:4-188:72 -/ +def Node.rotate_right_left + (T : Type) (root : Node T) (z : Node T) : Result (Node T) := + do + let (o, _) := core.mem.replace (Option (Node T)) z.left none + let y ← core.option.Option.unwrap (Node T) o + let (b, o1) := core.mem.replace (Option (Node T)) y.left none + let (a, o2) := core.mem.replace (Option (Node T)) y.right none + let (x, root1) := + core.mem.replace (Node T) (Node.mk root.value root.left b + root.balance_factor) (Node.mk y.value o1 o2 y.balance_factor) + if root1.balance_factor = 0#i8 + then + Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) + (some (Node.mk z.value a z.right 0#i8)) 0#i8) + else + if root1.balance_factor > 0#i8 + then + Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right + (-1)#i8)) (some (Node.mk z.value a z.right 0#i8)) 0#i8) + else + Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right + 0#i8)) (some (Node.mk z.value a z.right 1#i8)) 0#i8) + +/- [avl::{avl::Node}#2::insert_in_left]: + Source: 'src/avl.rs', lines 240:4-240:64 -/ +mutual divergent def Node.insert_in_left + (T : Type) (OrdInst : Ord T) (node : Node T) (value : T) : + Result (Bool × (Node T)) + := + do + let (b, o) ← Tree.insert_in_opt_node T OrdInst node.left value + if b + then + do + let i ← node.balance_factor - 1#i8 + if i = (-2)#i8 + then + do + let (o1, o2) := core.mem.replace (Option (Node T)) o none + let left ← core.option.Option.unwrap (Node T) o1 + if left.balance_factor <= 0#i8 + then + do + let node1 ← + Node.rotate_right T (Node.mk node.value o2 node.right i) left + Result.ok (false, node1) + else + do + let node1 ← + Node.rotate_left_right T (Node.mk node.value o2 node.right i) left + Result.ok (false, node1) + else Result.ok (i != 0#i8, Node.mk node.value o node.right i) + else Result.ok (false, Node.mk node.value o node.right node.balance_factor) + +/- [avl::{avl::Tree}#3::insert_in_opt_node]: + Source: 'src/avl.rs', lines 356:4-356:76 -/ +divergent def Tree.insert_in_opt_node + (T : Type) (OrdInst : Ord T) (node : Option (Node T)) (value : T) : + Result (Bool × (Option (Node T))) + := + match node with + | none => let n := Node.mk value none none 0#i8 + Result.ok (true, some n) + | some node1 => + do + let (b, node2) ← Node.insert T OrdInst node1 value + Result.ok (b, some node2) + +/- [avl::{avl::Node}#2::insert_in_right]: + Source: 'src/avl.rs', lines 277:4-277:65 -/ +divergent def Node.insert_in_right + (T : Type) (OrdInst : Ord T) (node : Node T) (value : T) : + Result (Bool × (Node T)) + := + do + let (b, o) ← Tree.insert_in_opt_node T OrdInst node.right value + if b + then + do + let i ← node.balance_factor + 1#i8 + if i = 2#i8 + then + do + let (o1, o2) := core.mem.replace (Option (Node T)) o none + let right ← core.option.Option.unwrap (Node T) o1 + if right.balance_factor >= 0#i8 + then + do + let node1 ← + Node.rotate_left T (Node.mk node.value node.left o2 i) right + Result.ok (false, node1) + else + do + let node1 ← + Node.rotate_right_left T (Node.mk node.value node.left o2 i) right + Result.ok (false, node1) + else Result.ok (i != 0#i8, Node.mk node.value node.left o i) + else Result.ok (false, Node.mk node.value node.left o node.balance_factor) + +/- [avl::{avl::Node}#2::insert]: + Source: 'src/avl.rs', lines 318:4-318:56 -/ +divergent def Node.insert + (T : Type) (OrdInst : Ord T) (node : Node T) (value : T) : + Result (Bool × (Node T)) + := + do + let ordering ← OrdInst.cmp value node.value + match ordering with + | Ordering.Less => Node.insert_in_left T OrdInst node value + | Ordering.Equal => Result.ok (false, node) + | Ordering.Greater => Node.insert_in_right T OrdInst node value + +end + +/- [avl::{avl::Tree}#3::new]: + Source: 'src/avl.rs', lines 338:4-338:24 -/ +def Tree.new (T : Type) (OrdInst : Ord T) : Result (Tree T) := + Result.ok { root := none } + +/- [avl::{avl::Tree}#3::find]: loop 0: + Source: 'src/avl.rs', lines 342:4-354:5 -/ +divergent def Tree.find_loop + (T : Type) (OrdInst : Ord T) (value : T) (current_tree : Option (Node T)) : + Result Bool + := + match current_tree with + | none => Result.ok false + | some current_node => + do + let o ← OrdInst.cmp current_node.value value + match o with + | Ordering.Less => Tree.find_loop T OrdInst value current_node.right + | Ordering.Equal => Result.ok true + | Ordering.Greater => Tree.find_loop T OrdInst value current_node.left + +/- [avl::{avl::Tree}#3::find]: + Source: 'src/avl.rs', lines 342:4-342:40 -/ +def Tree.find + (T : Type) (OrdInst : Ord T) (self : Tree T) (value : T) : Result Bool := + Tree.find_loop T OrdInst value self.root + +/- [avl::{avl::Tree}#3::insert]: + Source: 'src/avl.rs', lines 374:4-374:46 -/ +def Tree.insert + (T : Type) (OrdInst : Ord T) (self : Tree T) (value : T) : + Result (Bool × (Tree T)) + := + do + let (b, o) ← Tree.insert_in_opt_node T OrdInst self.root value + Result.ok (b, { root := o }) + +end avl diff --git a/tests/lean/Avl/OrderSpec.lean b/tests/lean/Avl/OrderSpec.lean new file mode 100644 index 000000000..344edf183 --- /dev/null +++ b/tests/lean/Avl/OrderSpec.lean @@ -0,0 +1,156 @@ +import Avl.Funs +open Primitives +open Result + +namespace avl + +@[simp] +def Ordering.toLeanOrdering (o: avl.Ordering): _root_.Ordering := match o with +| .Less => .lt +| .Equal => .eq +| .Greater => .gt + +@[simp] +def Ordering.ofLeanOrdering (o: _root_.Ordering): avl.Ordering := match o with +| .lt => .Less +| .eq => .Equal +| .gt => .Greater + +@[simp] +def Ordering.toDualOrdering (o: avl.Ordering): avl.Ordering := match o with +| .Less => .Greater +| .Equal => .Equal +| .Greater => .Less + +@[simp] +theorem Ordering.toLeanOrdering.injEq (x y: avl.Ordering): (x.toLeanOrdering = y.toLeanOrdering) = (x = y) := by + apply propext + cases x <;> cases y <;> simp + +@[simp] +theorem ite_eq_lt_distrib (c : Prop) [Decidable c] (a b : Ordering) : + ((if c then a else b) = .Less) = if c then a = .Less else b = .Less := by + by_cases c <;> simp [*] + +@[simp] +theorem ite_eq_eq_distrib (c : Prop) [Decidable c] (a b : Ordering) : + ((if c then a else b) = .Equal) = if c then a = .Equal else b = .Equal := by + by_cases c <;> simp [*] + +@[simp] +theorem ite_eq_gt_distrib (c : Prop) [Decidable c] (a b : Ordering) : + ((if c then a else b) = .Greater) = if c then a = .Greater else b = .Greater := by + by_cases c <;> simp [*] + +variable {T: Type} (H: outParam (Ord T)) + +@[simp] +def _root_.Ordering.toDualOrdering (o: _root_.Ordering): _root_.Ordering := match o with +| .lt => .gt +| .eq => .eq +| .gt => .lt + + +@[simp] +theorem toDualOrderingOfToLeanOrdering (o: avl.Ordering): o.toDualOrdering.toLeanOrdering = o.toLeanOrdering.toDualOrdering := by + cases o <;> simp + +@[simp] +theorem toDualOrderingIdempotency (o: _root_.Ordering): o.toDualOrdering.toDualOrdering = o := by + cases o <;> simp + +-- TODO: reason about raw bundling vs. refined bundling. +-- raw bundling: hypothesis with Rust extracted objects. +-- refined bundling: lifted hypothesis with Lean native objects. +class OrdSpec [_root_.Ord T] where + infallible: ∀ a b, ∃ (o: avl.Ordering), H.cmp a b = .ok o ∧ compare a b = o.toLeanOrdering + +class OrdSpecSymmetry [O: _root_.Ord T] extends OrdSpec H where + symmetry: ∀ a b, O.compare a b = (O.opposite.compare a b).toDualOrdering + +-- Must be R decidableRel and an equivalence relationship? +class OrdSpecRel [O: _root_.Ord T] (R: outParam (T -> T -> Prop)) extends OrdSpec H where + equivalence: ∀ a b, H.cmp a b = .ok .Equal -> R a b + +class OrdSpecLinearOrderEq [O: _root_.Ord T] extends OrdSpecSymmetry H, OrdSpecRel H Eq + +theorem infallible [_root_.Ord T] [OrdSpec H]: ∀ a b, ∃ o, H.cmp a b = .ok o := fun a b => by + obtain ⟨ o, ⟨ H, _ ⟩ ⟩ := OrdSpec.infallible a b + exact ⟨ o, H ⟩ + +instance: Coe (avl.Ordering) (_root_.Ordering) where + coe a := a.toLeanOrdering + +theorem rustCmpEq [_root_.Ord T] [O: OrdSpec H]: H.cmp a b = .ok o <-> compare a b = o.toLeanOrdering := by + apply Iff.intro + . intro Hcmp + obtain ⟨ o', ⟨ Hcmp', Hcompare ⟩ ⟩ := O.infallible a b + rw [Hcmp', ok.injEq] at Hcmp + simp [Hcompare, Hcmp', Hcmp] + . intro Hcompare + obtain ⟨ o', ⟨ Hcmp', Hcompare' ⟩ ⟩ := O.infallible a b + rw [Hcompare', avl.Ordering.toLeanOrdering.injEq] at Hcompare + simp [Hcompare.symm, Hcmp'] + + +theorem oppositeOfOpposite {x y: _root_.Ordering}: x.toDualOrdering = y ↔ x = y.toDualOrdering := by + cases x <;> cases y <;> simp +theorem oppositeRustOrder [_root_.Ord T] [Spec: OrdSpecSymmetry H] {a b}: H.cmp b a = .ok o ↔ H.cmp a b = .ok o.toDualOrdering := by + rw [rustCmpEq, Spec.symmetry, compare, Ord.opposite, oppositeOfOpposite, rustCmpEq, toDualOrderingOfToLeanOrdering] + +theorem ltOfRustOrder + [LO: LinearOrder T] + [Spec: OrdSpec H]: + ∀ a b, H.cmp a b = .ok .Less -> a < b := by + intros a b + intro Hcmp + -- why the typeclass search doesn't work here? + refine' (@compare_lt_iff_lt T LO).1 _ + obtain ⟨ o, ⟨ Hcmp', Hcompare ⟩ ⟩ := Spec.infallible a b + simp only [Hcmp', ok.injEq] at Hcmp + simp [Hcompare, Hcmp, avl.Ordering.toLeanOrdering] + +theorem gtOfRustOrder + [LinearOrder T] + [Spec: OrdSpecSymmetry H]: + ∀ a b, H.cmp a b = .ok .Greater -> b < a := by + intros a b + intro Hcmp + refine' @ltOfRustOrder _ H _ Spec.toOrdSpec _ _ _ + rewrite [oppositeRustOrder] + simp [Hcmp] + +-- TODO: move to standard library +@[simp] +theorem compare_eq_lt_iff [LinOrd : LinearOrder T] (x y : T) : + compare x y = Ordering.lt ↔ x < y := by + simp_all [LinOrd.compare_eq_compareOfLessAndEq, compareOfLessAndEq] + split <;> simp_all + +-- TODO: move to standard library +@[simp] +theorem compare_eq_equal_iff [LinOrd : LinearOrder T] (x y : T) : + compare x y = Ordering.eq ↔ x = y := by + simp_all [LinOrd.compare_eq_compareOfLessAndEq, compareOfLessAndEq] + split <;> simp_all + rw [eq_iff_le_not_lt] + tauto + +-- TODO: move to standard library +@[simp] +theorem compare_eq_gt_iff [LinOrd : LinearOrder T] (x y : T) : + compare x y = Ordering.gt ↔ y < x := by + simp_all [LinOrd.compare_eq_compareOfLessAndEq, compareOfLessAndEq] + split <;> simp_all + . rw [lt_iff_le_not_le] at *; tauto + . constructor + . intro Hneq + apply lt_of_le_of_ne <;> tauto + . intro Hlt + rw [eq_iff_le_not_lt] + simp + intro Hle + have := le_antisymm Hle + simp_all + +end avl diff --git a/tests/lean/Avl/Properties.lean b/tests/lean/Avl/Properties.lean new file mode 100644 index 000000000..5f9b3f653 --- /dev/null +++ b/tests/lean/Avl/Properties.lean @@ -0,0 +1,1030 @@ +import Avl.Funs +import Avl.Types +import Avl.OrderSpec + +namespace avl + +open Primitives Result + +-- TODO: move +@[simp] +def Option.allP {α : Type u} (p : α → Prop) (x : Option α) : Prop := + match x with + | none => true + | some x => p x + +abbrev Subtree (T : Type) := Option (Node T) + +mutual +@[simp] +def Node.height: Node T -> Nat +| Node.mk y left right _ => 1 + max (Subtree.height left) (Subtree.height right) + +@[simp] +def Subtree.height: Subtree T -> Nat +| none => 0 +| some n => Node.height n +end + +mutual +@[simp] +def Node.size: Node T -> Nat +| Node.mk y left right _ => 1 + Subtree.size left + Subtree.size right + +@[simp] +def Subtree.size: Subtree T -> Nat +| none => 0 +| some n => 1 + Node.size n +end + +def Tree.nil: Tree T := { root := none } + +-- Automatic synthesization of this seems possible at the Lean level? +instance: Inhabited (Tree T) where + default := Tree.nil + +instance [Inhabited T]: Inhabited (Node T) where + default := ⟨ Inhabited.default, none, none, 0#i8 ⟩ + +mutual +@[simp, reducible] def Subtree.v (tree: Subtree T) : Set T := + match tree with + | none => ∅ + | some node => node.v + +@[simp, reducible] def Node.v (node : Node T) : Set T := + match node with + | .mk x left right _ => {x} ∪ Subtree.v left ∪ Subtree.v right +end + +@[reducible] +def Tree.v (t: Tree T): Set T := Subtree.v t.root + +mutual +@[simp] def Subtree.forall (p: Node T -> Prop) (st : Subtree T) : Prop := + match st with + | none => true + | some st => st.forall p +termination_by Subtree.size st +decreasing_by simp_wf + +def Node.forall (p: Node T -> Prop) (node : Node T) : Prop := + p node ∧ + Subtree.forall p node.left ∧ Subtree.forall p node.right +termination_by Node.size node +decreasing_by all_goals (simp_wf; simp [Node.left, Node.right]; split <;> simp <;> scalar_tac) +end + +@[simp] +theorem Subtree.forall_left {p: Node T -> Prop} {t: Node T}: Node.forall p t -> Subtree.forall p t.left := by + cases t; simp_all [Node.forall] + +@[simp] +theorem Subtree.forall_right {p: Node T -> Prop} {t: Node T}: Subtree.forall p t -> Subtree.forall p t.right := by + cases t; simp_all [Node.forall] + +mutual +theorem Subtree.forall_imp {p q: Node T -> Prop} {t: Subtree T}: (∀ x, p x -> q x) -> Subtree.forall p t -> Subtree.forall q t + := by + match t with + | none => simp + | some node => + simp + intros + apply @Node.forall_imp T p q <;> tauto + +theorem Node.forall_imp {p q: Node T -> Prop} {t: Node T}: (∀ x, p x -> q x) -> Node.forall p t -> Node.forall q t := by + match t with + | .mk x left right height => + simp [Node.forall] + intros Himp Hleft Hright Hx + simp [*] + split_conjs <;> apply @Subtree.forall_imp T p q <;> tauto + +end + +def Node.balanceFactor (node : Node T) : ℤ := + Subtree.height node.right - Subtree.height node.left + +def Subtree.balanceFactor (t: Subtree T): ℤ := + match t with + | .none => 0 + | .some x => x.balanceFactor + +@[simp] +theorem Subtree.some_balanceFactor (t: Node T) : + Subtree.balanceFactor (some t) = t.balanceFactor := by + simp [balanceFactor] + +@[simp, reducible] +def Node.invAuxNotBalanced [LinearOrder T] (node : Node T) : Prop := + node.balance_factor.val = node.balanceFactor ∧ + (∀ x ∈ Subtree.v node.left, x < node.value) ∧ + (∀ x ∈ Subtree.v node.right, node.value < x) + +def Node.invAux [LinearOrder T] (node : Node T) : Prop := + Node.invAuxNotBalanced node ∧ + -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 + +@[reducible] +def Node.inv [LinearOrder T] (node : Node T) : Prop := + Node.forall Node.invAux node + +-- TODO: use a custom set +@[aesop safe forward] +theorem Node.inv_imp_current [LinearOrder T] {node : Node T} (hInv : node.inv) : + node.balance_factor.val = node.balanceFactor ∧ + (∀ x ∈ Subtree.v node.left, x < node.value) ∧ + (∀ x ∈ Subtree.v node.right, node.value < x) ∧ + -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 := by + simp_all [Node.inv, Node.forall, Node.invAux] + +@[reducible] +def Subtree.inv [LinearOrder T] (st : Subtree T) : Prop := + match st with + | none => true + | some node => node.inv + +@[simp] +theorem Subtree.inv_some [LinearOrder T] (s : Node T) : Subtree.inv (some s) = s.inv := by + rfl + +@[reducible] +def Tree.height (t : Tree T) := Subtree.height t.root + +@[reducible] +def Tree.inv [LinearOrder T] (t : Tree T) : Prop := Subtree.inv t.root + +@[simp] +theorem Node.inv_mk [LinearOrder T] (value : T) (left right : Option (Node T)) (bf : I8): + (Node.mk value left right bf).inv ↔ + ((Node.mk value left right bf).invAux ∧ + Subtree.inv left ∧ Subtree.inv right) := by + have : ∀ (n : Option (Node T)), Subtree.forall invAux n = Subtree.inv n := by + unfold Subtree.forall + simp [Subtree.inv] + constructor <;> + simp [*, Node.inv, Node.forall] + +@[simp] +theorem Node.inv_left [LinearOrder T] {t: Node T}: t.inv -> Subtree.inv t.left := by + simp [Node.inv] + intro + cases t; simp_all + +@[simp] +theorem Node.inv_right [LinearOrder T] {t: Node T}: t.inv -> Subtree.inv t.right := by + simp [Node.inv] + intro + cases t; simp_all + +theorem Node.inv_imp_balance_factor_eq [LinearOrder T] {t: Node T} (hInv : t.inv) : + t.balance_factor.val = t.balanceFactor := by + simp [inv, Node.forall, invAux] at hInv + cases t; simp_all + +@[simp] +theorem Node.lt_imp_not_in_right [LinearOrder T] (node : Node T) (hInv : node.inv) (x : T) + (hLt : x < node.value) : + x ∉ Subtree.v node.right := by + have ⟨ _, _, h, _ ⟩ := Node.inv_imp_current hInv + intro hIn + have := h x hIn + have := lt_asymm this + tauto + +@[simp] +theorem Node.lt_imp_not_in_left [LinearOrder T] (node : Node T) (hInv : node.inv) (x : T) + (hLt : node.value < x) : + x ∉ Subtree.v node.left := by + have ⟨ _, h, _, _ ⟩ := Node.inv_imp_current hInv + intro hIn + have := h x hIn + have := lt_asymm this + tauto + +@[simp] +theorem Node.value_not_in_right [LinearOrder T] (node : Node T) (hInv : node.inv) : + node.value ∉ Subtree.v node.right := by + have ⟨ _, _, h, _ ⟩ := Node.inv_imp_current hInv + intro hIn + have := h node.value hIn + have := ne_of_lt this + tauto + +@[simp] +theorem Node.value_not_in_left [LinearOrder T] (node : Node T) (hInv : node.inv) : + node.value ∉ Subtree.v node.left := by + have ⟨ _, h, _, _ ⟩ := Node.inv_imp_current hInv + intro hIn + have := h node.value hIn + have := ne_of_lt this + tauto + +@[pspec] +theorem Tree.find_loop_spec + {T : Type} (OrdInst : Ord T) + [DecidableEq T] [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (value : T) (t : Subtree T) (hInv : Subtree.inv t) : + ∃ b, Tree.find_loop T OrdInst value t = ok b ∧ + (b ↔ value ∈ Subtree.v t) := by + rewrite [find_loop] + match t with + | none => simp + | some (.mk v left right height) => + dsimp only + have hCmp := Ospec.infallible -- TODO + progress keep Hordering as ⟨ ordering ⟩; clear hCmp + have hInvLeft := Node.inv_left hInv + have hInvRight := Node.inv_right hInv + cases ordering <;> dsimp only + . /- node.value < value -/ + progress + have hNotIn := Node.lt_imp_not_in_left _ hInv + simp_all + intro; simp_all + . /- node.value = value -/ + simp_all + . /- node.value > value -/ + progress + have hNotIn := Node.lt_imp_not_in_right _ hInv + simp_all + intro; simp_all + +theorem Tree.find_spec + {T : Type} (OrdInst : Ord T) + [DecidableEq T] [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (t : Tree T) (value : T) (hInv : t.inv) : + ∃ b, Tree.find T OrdInst t value = ok b ∧ + (b ↔ value ∈ t.v) := by + rw [find] + progress + simp [Tree.v]; assumption + +-- TODO: move +set_option maxHeartbeats 5000000 + +@[pspec] +theorem Node.rotate_left_spec + {T : Type} [LinearOrder T] + (x z : T) (a b c : Option (Node T)) (bf_x bf_z : I8) + -- Invariants for the subtrees + (hInvA : Subtree.inv a) + (hInvZ : Node.inv ⟨ z, b, c, bf_z ⟩) + -- Invariant for the complete tree (but without the bounds on the balancing operation) + (hInvX : Node.invAuxNotBalanced ⟨ x, a, some ⟨ z, b, c, bf_z ⟩, bf_x ⟩) + -- The tree is not balanced + (hBfX : bf_x.val = 2) + -- Z has a positive balance factor + (hBfZ : 0 ≤ bf_z.val) + : + ∃ ntree, rotate_left T ⟨ x, a, none, bf_x ⟩ ⟨ z, b, c, bf_z ⟩ = ok ntree ∧ + let tree : Node T := ⟨ x, a, some ⟨ z, b, c, bf_z ⟩, bf_x ⟩ + -- We reestablished the invariant + Node.inv ntree ∧ + -- The tree contains the nodes we expect + Node.v ntree = Node.v tree ∧ + -- The height is the same as before. The original height is 2 + height b, and by + -- inserting an element (which produced subtree c) we got a new height of + -- 3 + height b; after the rotation we get back to 2 + height b. + Node.height ntree = 2 + Subtree.height b + := by + rw [rotate_left] + simp [core.mem.replace] + -- Some proofs common to both cases + -- Elements in the left subtree are < z + have : ∀ (y : T), (y = x ∨ y ∈ Subtree.v a) ∨ y ∈ Subtree.v b → y < z := by + simp [invAux] at hInvZ + intro y hIn + -- TODO: automate that + cases hIn + . rename _ => hIn + cases hIn + . simp [*] + . -- Proving: y ∈ a → y < z + -- Using: y < x ∧ x < z + rename _ => hIn + have hInv1 : y < x := by tauto + have hInv2 := hInvX.right.right z + simp at hInv2 + apply lt_trans hInv1 hInv2 + . tauto + -- Elements in the right subtree are < z + have : ∀ y ∈ Subtree.v c, z < y := by + simp [invAux] at hInvZ + tauto + -- Two cases depending on whether the BF of Z is 0 or 1 + split + . -- BF(Z) == 0 + simp at * + simp [*] + have hHeightEq : Subtree.height b = Subtree.height c := by + simp_all [balanceFactor, Node.invAux] + -- TODO: scalar_tac fails here (conversion int/nat) + omega + -- TODO: we shouldn't need this: scalar_tac should succeed + have : 1 + Subtree.height c = Subtree.height a + 2 := by + -- TODO: scalar_tac fails here (conversion int/nat) + simp_all [balanceFactor, Node.invAux] + omega + simp_all + split_conjs + . -- Partial invariant for the final tree starting at z + simp [Node.invAux, balanceFactor, *] + split_conjs <;> (try omega) <;> tauto + . -- Partial invariant for the subtree x + simp [Node.invAux, balanceFactor, *] + split_conjs <;> (try omega) <;> simp_all + . -- The sets are the same + apply Set.ext; simp; tauto + . -- The height didn't change + simp [balanceFactor] at * + replace hInvX := hInvX.left + simp_all + scalar_tac + . -- BF(Z) == 1 + rename _ => hNotEq + simp at * + simp [*] + simp_all + have : bf_z.val = 1 := by + simp [Node.invAux] at hInvZ + omega + clear hNotEq hBfZ + have : Subtree.height c = 1 + Subtree.height b := by + simp [balanceFactor, Node.invAux] at * + replace hInvZ := hInvZ.left + omega + have : max (Subtree.height c) (Subtree.height b) = Subtree.height c := by + scalar_tac + -- TODO: we shouldn't need this: scalar_tac should succeed + have : Subtree.height c = 1 + Subtree.height a := by + -- TODO: scalar_tac fails here (conversion int/nat) + simp_all [balanceFactor, Node.invAux] + omega + have : Subtree.height a = Subtree.height b := by + simp_all + split_conjs + . -- Invariant for whole tree (starting at z) + simp [invAux, balanceFactor] + split_conjs <;> (try omega) <;> tauto + . -- Invariant for subtree x + simp [invAux, balanceFactor] + split_conjs <;> (try omega) <;> simp_all + . -- The sets are the same + apply Set.ext; simp; tauto + . -- The height didn't change + simp [balanceFactor] at * + replace hInvX := hInvX.left + simp_all + scalar_tac + +@[pspec] +theorem Node.rotate_right_spec + {T : Type} [LinearOrder T] + (x z : T) (a b c : Option (Node T)) (bf_x bf_z : I8) + -- Invariants for the subtrees + (hInvC : Subtree.inv c) + (hInvZ : Node.inv ⟨ z, a, b, bf_z ⟩) + -- Invariant for the complete tree (but without the bounds on the balancing operation) + (hInvX : Node.invAuxNotBalanced ⟨ x, some ⟨ z, a, b, bf_z ⟩, c, bf_x ⟩) + -- The tree is not balanced + (hBfX : bf_x.val = -2) + -- Z has a positive balance factor + (hBfZ : bf_z.val ≤ 0) + : + ∃ ntree, rotate_right T ⟨ x, none, c, bf_x ⟩ ⟨ z, a, b, bf_z ⟩ = ok ntree ∧ + let tree : Node T := ⟨ x, some ⟨ z, a, b, bf_z ⟩, c, bf_x ⟩ + -- We reestablished the invariant + Node.inv ntree ∧ + -- The tree contains the nodes we expect + Node.v ntree = Node.v tree ∧ + -- The height is the same as before. The original height is 2 + height b, and by + -- inserting an element (which produced subtree c) we got a new height of + -- 3 + height b; after the rotation we get back to 2 + height b. + Node.height ntree = 2 + Subtree.height b + := by + rw [rotate_right] + simp [core.mem.replace] + -- Some proofs common to both cases + -- Elements in the right subtree are > z + have : ∀ (y : T), (y = x ∨ y ∈ Subtree.v b) ∨ y ∈ Subtree.v c → z < y := by + simp [invAux] at * + intro y hIn + -- TODO: automate that + cases hIn + . rename _ => hIn + cases hIn + . simp [*] + . tauto + . -- Proving: y ∈ c → z < y + -- Using: z < x ∧ x < y + have : z < x := by tauto + have : x < y := by tauto + apply lt_trans <;> tauto + -- Elements in the left subtree are < z + have : ∀ y ∈ Subtree.v a, y < z := by + simp_all [invAux] + -- Two cases depending on whether the BF of Z is 0 or 1 + split + . -- BF(Z) == 0 + simp at * + simp [*] + have hHeightEq : Subtree.height a = Subtree.height b := by + simp_all [balanceFactor, Node.invAux] + -- TODO: scalar_tac fails here (conversion int/nat) + omega + -- TODO: we shouldn't need this: scalar_tac should succeed + have : 1 + Subtree.height a = Subtree.height c + 2 := by + -- TODO: scalar_tac fails here (conversion int/nat) + simp_all [balanceFactor, Node.invAux] + omega + simp_all + split_conjs + . -- Partial invariant for the final tree starting at z + simp [Node.invAux, balanceFactor, *] + split_conjs <;> (try omega) <;> tauto + . -- Partial invariant for the subtree x + simp [Node.invAux, balanceFactor, *] + split_conjs <;> (try omega) <;> simp_all + . -- The sets are the same + apply Set.ext; simp; tauto + . -- The height didn't change + simp [balanceFactor] at * + replace hInvX := hInvX.left + simp_all + scalar_tac + . -- BF(Z) == -1 + rename _ => hNotEq + simp at * + simp [*] + simp_all + have : bf_z.val = -1 := by + simp [Node.invAux] at hInvZ + omega + clear hNotEq hBfZ + have : Subtree.height a = 1 + Subtree.height b := by + simp [balanceFactor, Node.invAux] at * + replace hInvZ := hInvZ.left + omega + have : max (Subtree.height a) (Subtree.height b) = Subtree.height a := by + scalar_tac + -- TODO: we shouldn't need this: scalar_tac should succeed + have : Subtree.height a = 1 + Subtree.height c := by + -- TODO: scalar_tac fails here (conversion int/nat) + simp_all [balanceFactor, Node.invAux] + omega + have : Subtree.height c = Subtree.height b := by + simp_all + split_conjs + . -- Invariant for whole tree (starting at z) + simp [invAux, balanceFactor] + split_conjs <;> (try omega) <;> tauto + . -- Invariant for subtree x + simp [invAux, balanceFactor] + split_conjs <;> (try omega) <;> simp_all + . -- The sets are the same + apply Set.ext; simp; tauto + . -- The height didn't change + simp [balanceFactor] at * + replace hInvX := hInvX.left + simp_all + scalar_tac + +@[pspec] +theorem Node.rotate_left_right_spec + {T : Type} [LinearOrder T] + (x y z : T) (bf_x bf_y bf_z : I8) + (a b t0 t1 : Option (Node T)) + -- Invariants for the subtrees + (hInvX : Node.invAuxNotBalanced ⟨ x, some ⟨ z, t0, some ⟨ y, a, b, bf_y ⟩, bf_z ⟩, t1, bf_x ⟩) + (hInvZ : Node.inv ⟨ z, t0, some ⟨ y, a, b, bf_y ⟩, bf_z ⟩) + (hInv1 : Subtree.inv t1) + -- The tree is not balanced + (hBfX : bf_x.val = -2) + -- Z has a positive balance factor + (hBfZ : bf_z.val = 1) + : + let x_tree := ⟨ x, none, t1, bf_x ⟩ + let y_tree := ⟨ y, a, b, bf_y ⟩ + let z_tree := ⟨ z, t0, some y_tree, bf_z ⟩ + let tree : Node T := ⟨ x, some z_tree, t1, bf_x ⟩ + ∃ ntree, rotate_left_right T x_tree z_tree = ok ntree ∧ + -- We reestablished the invariant + Node.inv ntree ∧ + -- The tree contains the nodes we expect + Node.v ntree = Node.v tree ∧ + -- The height is the same as before. The original height is 2 + height a, and by + -- inserting an element (which produced subtree c) we got a new height of + -- 3 + height b; after the rotation we get back to 2 + height b. + Node.height ntree = 2 + Subtree.height t0 + := by + intro x_tree y_tree z_tree tree + simp [rotate_left_right] -- TODO: this inlines the local decls + -- Some facts about the heights and the balance factors + -- TODO: automate that + have : Node.height z_tree = Subtree.height t1 + 2 := by + simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + have : Node.height y_tree = Subtree.height t0 + 1 := by + simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + have : bf_y.val + Subtree.height a = Subtree.height b := by + simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + simp [x_tree, y_tree, z_tree] at * + -- TODO: automate the < proofs + -- Auxiliary proofs for invAux for y + have : ∀ (e : T), (e = z ∨ e ∈ Subtree.v t0) ∨ e ∈ Subtree.v a → e < y := by + intro e hIn + simp [invAux] at * + cases hIn + . rename _ => hIn + -- TODO: those cases are cumbersome + cases hIn + . simp_all + . have : e < z := by tauto + have : z < y := by tauto + apply lt_trans <;> tauto + . tauto + have : ∀ (e : T), (e = x ∨ e ∈ Subtree.v b) ∨ e ∈ Subtree.v t1 → y < e := by + intro e hIn; simp [invAux] at * + cases hIn + . rename _ => hIn + cases hIn + . simp_all + . tauto + . have : y < x := by + replace hInvX := hInvX.right.left y + tauto + have : x < e := by tauto + apply lt_trans <;> tauto + -- Auxiliary proofs for invAux for z + have : ∀ e ∈ Subtree.v t0, e < z := by + intro x hIn; simp [invAux] at * + tauto + have : ∀ e ∈ Subtree.v a, z < e := by + intro e hIn; simp [invAux] at * + replace hInvZ := hInvZ.right.right.left e + tauto + -- Auxiliary proofs for invAux for x + have : ∀ e ∈ Subtree.v b, e < x := by + intro e hIn; simp [invAux] at * + replace hInvX := hInvX.right.left e + tauto + have : ∀ e ∈ Subtree.v t1, x < e := by + intro e hIn; simp [invAux] at * + tauto + -- Case disjunction on the balance factor of Y + split + . -- BF(Y) = 0 + simp [balanceFactor] at * + split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + . -- invAux for y + split_conjs <;> (try omega) <;> (try tauto) + . -- invAux for z + split_conjs <;> (try scalar_tac) <;> tauto + . -- invAux for x + split_conjs <;> (try scalar_tac) <;> tauto + . -- The sets are the same + apply Set.ext; simp [tree, z_tree, y_tree]; tauto + . -- Height + scalar_tac + . split <;> simp + . -- BF(Y) < 0 + have : bf_y.val = -1 := by simp [Node.invAux] at *; omega + simp [balanceFactor] at * + split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + . -- invAux for y + split_conjs <;> (try omega) <;> (try tauto) + . -- invAux for z + split_conjs <;> (try scalar_tac) <;> tauto + . -- invAux for x + split_conjs <;> (try scalar_tac) <;> tauto + . -- The sets are the same + apply Set.ext; simp [tree, z_tree, y_tree]; tauto + . -- Height + scalar_tac + . -- BF(Y) > 0 + have : bf_y.val = 1 := by simp [Node.invAux] at *; omega + split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + . -- invAux for y + split_conjs <;> (try omega) <;> (try tauto) + . -- invAux for z + split_conjs <;> (try scalar_tac) <;> tauto + . -- invAux for x + split_conjs <;> (try scalar_tac) <;> tauto + . -- The sets are the same + apply Set.ext; simp [tree, z_tree, y_tree]; tauto + . -- Height + scalar_tac + +@[pspec] +theorem Node.rotate_right_left_spec + {T : Type} [LinearOrder T] + (x y z : T) (bf_x bf_y bf_z : I8) + (a b t0 t1 : Option (Node T)) + -- Invariants for the subtrees + (hInvX : Node.invAuxNotBalanced ⟨ x, t1, some ⟨ z, some ⟨ y, b, a, bf_y ⟩, t0, bf_z ⟩, bf_x ⟩) + (hInvZ : Node.inv ⟨ z, some ⟨ y, b, a, bf_y ⟩, t0, bf_z ⟩) + (hInv1 : Subtree.inv t1) + -- The tree is not balanced + (hBfX : bf_x.val = 2) + -- Z has a negative balance factor + (hBfZ : bf_z.val = -1) + : + let x_tree := ⟨ x, t1, none, bf_x ⟩ + let y_tree := ⟨ y, b, a, bf_y ⟩ + let z_tree := ⟨ z, some y_tree, t0, bf_z ⟩ + let tree : Node T := ⟨ x, t1, some z_tree, bf_x ⟩ + ∃ ntree, rotate_right_left T x_tree z_tree = ok ntree ∧ + -- We reestablished the invariant + Node.inv ntree ∧ + -- The tree contains the nodes we expect + Node.v ntree = Node.v tree ∧ + -- The height is the same as before. The original height is 2 + height b, and by + -- inserting an element (which produced subtree c) we got a new height of + -- 3 + height b; after the rotation we get back to 2 + height b. + Node.height ntree = 2 + Subtree.height t1 + := by + intro x_tree y_tree z_tree tree + simp [rotate_right_left] -- TODO: this inlines the local decls + -- Some facts about the heights and the balance factors + -- TODO: automate that + have : Node.height z_tree = Subtree.height t1 + 2 := by + simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + have : Node.height y_tree = Subtree.height t0 + 1 := by + simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + have : bf_y.val + Subtree.height b = Subtree.height a := by + simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + simp [x_tree, y_tree, z_tree] at * + -- TODO: automate the < proofs + -- Auxiliary proofs for invAux for y + have : ∀ (e : T), (e = z ∨ e ∈ Subtree.v a) ∨ e ∈ Subtree.v t0 → y < e := by + intro e hIn + simp [invAux] at * + cases hIn + . rename _ => hIn + -- TODO: those cases are cumbersome + cases hIn + . simp_all + . tauto + . have : z < e := by tauto + have : y < z := by tauto + apply lt_trans <;> tauto + have : ∀ (e : T), (e = x ∨ e ∈ Subtree.v t1) ∨ e ∈ Subtree.v b → e < y := by + intro e hIn; simp [invAux] at * + cases hIn + . rename _ => hIn + cases hIn + . simp_all + . have : x < y := by + replace hInvX := hInvX.right.right y + tauto + have : e < x := by tauto + apply lt_trans <;> tauto + . tauto + -- Auxiliary proofs for invAux for z + have : ∀ e ∈ Subtree.v t0, z < e := by + intro x hIn; simp [invAux] at * + tauto + have : ∀ e ∈ Subtree.v a, e < z := by + intro e hIn; simp [invAux] at * + replace hInvZ := hInvZ.right.left e + tauto + -- Auxiliary proofs for invAux for x + have : ∀ e ∈ Subtree.v b, x < e := by + intro e hIn; simp [invAux] at * + replace hInvX := hInvX.right.right e + tauto + have : ∀ e ∈ Subtree.v t1, e < x := by + intro e hIn; simp [invAux] at * + tauto + -- Case disjunction on the balance factor of Y + split + . -- BF(Y) = 0 + simp [balanceFactor] at * + split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + . -- invAux for y + split_conjs <;> (try omega) <;> (try tauto) + . -- invAux for z + split_conjs <;> (try scalar_tac) <;> tauto + . -- invAux for x + split_conjs <;> (try scalar_tac) <;> tauto + . -- The sets are the same + apply Set.ext; simp [tree, z_tree, y_tree]; tauto + . -- Height + scalar_tac + . split <;> simp + . -- BF(Y) > 0 + have : bf_y.val = 1 := by simp [Node.invAux] at *; omega + simp [balanceFactor] at * + split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + . -- invAux for y + split_conjs <;> (try omega) <;> (try tauto) + . -- invAux for z + split_conjs <;> (try scalar_tac) <;> tauto + . -- invAux for x + split_conjs <;> (try scalar_tac) <;> tauto + . -- The sets are the same + apply Set.ext; simp [tree, z_tree, y_tree]; tauto + . -- Height + scalar_tac + . -- BF(Y) < 0 + have : bf_y.val = -1 := by simp [Node.invAux] at *; omega + split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + . -- invAux for y + split_conjs <;> (try omega) <;> (try tauto) + . -- invAux for z + split_conjs <;> (try scalar_tac) <;> tauto + . -- invAux for x + split_conjs <;> (try scalar_tac) <;> tauto + . -- The sets are the same + apply Set.ext; simp [tree, z_tree, y_tree]; tauto + . -- Height + scalar_tac + +-- This rewriting lemma is problematic below +attribute [-simp] Bool.exists_bool + +-- For the proofs of termination +@[simp] +theorem Node.left_height_lt_height (n : Node T) : + Subtree.height n.left < n.height := by + cases n; simp; scalar_tac + +@[simp] +theorem Node.right_height_lt_height (n : Node T) : + Subtree.height n.right < n.height := by + cases n; simp; scalar_tac + +mutual + +@[pspec] +theorem Node.insert_spec + {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (node : Node T) (value : T) + (hInv : Node.inv node) : + ∃ b node', Node.insert T OrdInst node value = ok (b, node') ∧ + Node.inv node' ∧ + Node.v node' = Node.v node ∪ {value} ∧ + (if b then node'.height = node.height + 1 else node'.height = node.height) ∧ + -- This is important for some of the proofs + (b → node'.balanceFactor ≠ 0) := by + rw [Node.insert] + have hCmp := Ospec.infallible -- TODO + progress as ⟨ ordering ⟩ + split <;> rename _ => hEq <;> clear hCmp <;> simp at * + . -- value < node.value + progress as ⟨ updt, node', h1, h2 ⟩ + simp_all + . -- value = node.value + cases node; simp_all + . -- node.value < value + progress as ⟨ updt, node', h1, h2 ⟩ + simp_all +termination_by (node.height, 1) +decreasing_by all_goals simp_wf + +@[pspec] +theorem Tree.insert_in_opt_node_spec + {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (tree : Option (Node T)) (value : T) + (hInv : Subtree.inv tree) : + ∃ b tree', Tree.insert_in_opt_node T OrdInst tree value = ok (b, tree') ∧ + Subtree.inv tree' ∧ + Subtree.v tree' = Subtree.v tree ∪ {value} ∧ + (if b then Subtree.height tree' = Subtree.height tree + 1 + else Subtree.height tree' = Subtree.height tree) ∧ + (b → Subtree.height tree > 0 → Subtree.balanceFactor tree' ≠ 0) := by + rw [Tree.insert_in_opt_node] + cases hNode : tree <;> simp [hNode] + . -- tree = none + split_conjs + . simp [Node.invAux, Node.balanceFactor] + . simp [Subtree.inv] + . apply Set.ext; simp + . -- tree = some + rename Node T => node + have hNodeInv : Node.inv node := by simp_all + progress as ⟨ updt, tree' ⟩ + simp_all +termination_by (Subtree.height tree, 2) +decreasing_by simp_wf; simp [*] + +-- TODO: any modification triggers the replay of the whole proof +@[pspec] +theorem Node.insert_in_left_spec + {T : Type} (OrdInst : Ord T) + [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (node : Node T) (value : T) + (hInv : Node.inv node) + (hLt : value < node.value) : + ∃ b node', Node.insert_in_left T OrdInst node value = ok (b, node') ∧ + Node.inv node' ∧ + Node.v node' = Node.v node ∪ {value} ∧ + (if b then node'.height = node.height + 1 else node'.height = node.height) ∧ + (b → node'.balanceFactor ≠ 0) := by + rw [Node.insert_in_left] + have hInvLeft : Subtree.inv node.left := by cases node; simp_all + progress as ⟨ updt, left_opt' .. ⟩ + split + . -- the height of the subtree changed + have hBalanceFactor : node.balance_factor = node.balanceFactor ∧ + -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 := by + cases node; simp_all [Node.invAux] + progress as ⟨ i .. ⟩ + split + . -- i = -2 + simp + cases h: left_opt' with + | none => simp_all -- absurd + | some left' => + simp [h] + cases node with | mk x left right balance_factor => + split + . -- rotate_right + -- TODO: fix progress + cases h:left' with | mk z a b bf_z => + progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ + -- TODO: syntax for preconditions + . simp_all + . simp_all + . simp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor] + scalar_tac + . simp_all + . -- End of the proof + simp [*] + split_conjs + . -- set reasoning + simp_all + apply Set.ext; simp + intro x; tauto + . -- height + simp_all [Node.invAux, Node.balanceFactor] + -- This assertion is not necessary for the proof, but it is important that it holds. + -- We can prove it because of the post-conditions `b → node'.balanceFactor ≠ 0` (see above) + have : bf_z.val = -1 := by scalar_tac + scalar_tac + . -- rotate_left_right + simp + cases h:left' with | mk z t0 y bf_z => + cases h: y with + | none => + -- Can't get there + simp_all [Node.balanceFactor, Node.invAux] + | some y => + cases h: y with | mk y a b bf_y => + progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ + -- TODO: syntax for preconditions + . simp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac + . simp_all + . simp_all + . simp_all [Node.invAux, Node.balanceFactor]; scalar_tac + . -- End of the proof + simp [*] + split_conjs + . apply Set.ext; simp_all + intro x; tauto + . simp_all [Node.invAux, Node.balanceFactor] + scalar_tac + . -- i ≠ -2: the height of the tree did not change + simp [*] + split_conjs + . cases node; simp_all [Node.invAux, Node.balanceFactor] + split_conjs <;> scalar_tac + . apply Set.ext; simp + cases node; simp_all + tauto + . simp_all + cases node with | mk node_value left right balance_factor => + split <;> simp [Node.balanceFactor] at * <;> scalar_tac + . simp_all [Node.balanceFactor] + scalar_tac + . -- the height of the subtree did not change + simp [*] + split_conjs + . cases node; + simp_all [Node.invAux, Node.balanceFactor] + . apply Set.ext; simp; intro x + cases node; simp_all + tauto + . simp_all + cases node; simp_all +termination_by (node.height, 0) +decreasing_by simp_wf + +@[pspec] +theorem Node.insert_in_right_spec + {T : Type} (OrdInst : Ord T) + [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (node : Node T) (value : T) + (hInv : Node.inv node) + (hGt : value > node.value) : + ∃ b node', Node.insert_in_right T OrdInst node value = ok (b, node') ∧ + Node.inv node' ∧ + Node.v node' = Node.v node ∪ {value} ∧ + (if b then node'.height = node.height + 1 else node'.height = node.height) ∧ + (b → node'.balanceFactor ≠ 0) := by + rw [Node.insert_in_right] + have hInvLeft : Subtree.inv node.right := by cases node; simp_all + progress as ⟨ updt, right_opt' .. ⟩ + split + . -- the height of the subtree changed + have hBalanceFactor : node.balance_factor = node.balanceFactor ∧ + -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 := by + cases node; simp_all [Node.invAux] + progress as ⟨ i .. ⟩ + split + . -- i = 2 + simp + cases h: right_opt' with + | none => simp_all -- absurd + | some right' => + simp [h] + split + . -- rotate_left + cases node with | mk x a right balance_factor => + -- TODO: fix progress + cases h:right' with | mk z b c bf_z => + progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ + -- TODO: syntax for preconditions + . simp_all + . simp_all + . simp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac + . simp_all + . -- End of the proof + simp [*] + split_conjs + . -- set reasoning + simp_all + . -- height + simp_all [Node.invAux, Node.balanceFactor] + -- This assertion is not necessary for the proof, but it is important that it holds. + -- We can prove it because of the post-conditions `b → node'.balanceFactor ≠ 0` (see above) + have : bf_z.val = 1 := by scalar_tac + scalar_tac + . -- rotate_right_left + cases node with | mk x t1 right balance_factor => + simp + cases h:right' with | mk z y t0 bf_z => + cases h: y with + | none => + -- Can't get there + simp_all [Node.balanceFactor, Node.invAux] + | some y => + cases h: y with | mk y b a bf_y => + progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ + -- TODO: syntax for preconditions + . simp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac + . simp_all + . simp_all + . simp_all [Node.invAux, Node.balanceFactor]; scalar_tac + . -- End of the proof + simp [*] + split_conjs + . apply Set.ext; simp_all + . simp_all [Node.invAux, Node.balanceFactor] + scalar_tac + . -- i ≠ -2: the height of the tree did not change + simp [*] + split_conjs + . cases node; simp_all [Node.invAux, Node.balanceFactor] + split_conjs <;> scalar_tac + . apply Set.ext; simp + cases node; simp_all + . simp_all + cases node with | mk node_value left right balance_factor => + split <;> simp [Node.balanceFactor] at * <;> scalar_tac + . simp_all [Node.balanceFactor] + scalar_tac + . -- the height of the subtree did not change + simp [*] -- TODO: annoying to use this simp everytime: put this in progress + split_conjs + . cases node; + simp_all [Node.invAux, Node.balanceFactor] + . apply Set.ext; simp; intro x + cases node; simp_all + . simp_all + cases node; simp_all +termination_by (node.height, 0) +decreasing_by simp_wf + +end + +@[pspec] +theorem Tree.insert_spec {T : Type} + (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] + (tree : Tree T) (value : T) + (hInv : tree.inv) : + ∃ updt tree', Tree.insert T OrdInst tree value = ok (updt, tree') ∧ + tree'.inv ∧ + (if updt then tree'.height = tree.height + 1 else tree'.height = tree.height) ∧ + tree'.v = tree.v ∪ {value} := by + rw [Tree.insert] + progress as ⟨ updt, tree' ⟩ + simp [*] + +@[pspec] +theorem Tree.new_spec {T : Type} (OrdInst : Ord T) : + ∃ t, Tree.new T OrdInst = ok t ∧ t.v = ∅ ∧ t.height = 0 := by + simp [new, Tree.v, Tree.height] + +end avl diff --git a/tests/lean/Avl/ScalarOrder.lean b/tests/lean/Avl/ScalarOrder.lean new file mode 100644 index 000000000..198df5457 --- /dev/null +++ b/tests/lean/Avl/ScalarOrder.lean @@ -0,0 +1,59 @@ +import Avl.OrderSpec + +namespace avl + +open Primitives + +-- TODO: move +instance ScalarDecidableLE {ty} : DecidableRel (· ≤ · : Scalar ty -> Scalar ty -> Prop) := by + simp [instLEScalar] + -- Lift this to the decidability of the Int version. + infer_instance + +-- TODO: move +instance {ty} : LinearOrder (Scalar ty) where + le_antisymm := fun a b Hab Hba => by + apply (Scalar.eq_equiv a b).2; exact (Int.le_antisymm ((Scalar.le_equiv _ _).1 Hab) ((Scalar.le_equiv _ _).1 Hba)) + le_total := fun a b => by + rcases (Int.le_total a b) with H | H + left; exact (Scalar.le_equiv _ _).2 H + right; exact (Scalar.le_equiv _ _).2 H + decidableLE := ScalarDecidableLE + +instance : OrdSpecLinearOrderEq OrdI32 where + infallible := fun a b => by + unfold Ord.cmp + unfold OrdI32 + unfold OrdI32.cmp + rw [LinearOrder.compare_eq_compareOfLessAndEq, compareOfLessAndEq] + if hlt : a < b then + use .Less + simp [*] + intros; scalar_tac -- Contradiction + else + if heq: a = b + then + use .Equal + simp [hlt, *] + else + use .Greater + simp [hlt, heq] + scalar_tac + symmetry := fun a b => by + simp [Ordering.toDualOrdering, LinearOrder.compare_eq_compareOfLessAndEq, compareOfLessAndEq] + rw [compare, Ord.opposite] + simp [LinearOrder.compare_eq_compareOfLessAndEq, compareOfLessAndEq] + split_ifs with hab hba hba' hab' hba'' _ hba₃ _ <;> tauto + exact lt_irrefl _ (lt_trans hab hba) + rw [hba'] at hab; exact lt_irrefl _ hab + rw [hab'] at hba''; exact lt_irrefl _ hba'' + -- The order is total, therefore, we have at least one case where we are comparing something. + cases (lt_trichotomy a b) <;> tauto + equivalence := fun a b => by + unfold Ord.cmp + unfold OrdI32 + unfold OrdI32.cmp + simp only [] + split_ifs <;> simp only [Result.ok.injEq, not_false_eq_true, neq_imp, IsEmpty.forall_iff]; tauto; try assumption + +end avl diff --git a/tests/lean/Avl/Types.lean b/tests/lean/Avl/Types.lean new file mode 100644 index 000000000..f2820dcb9 --- /dev/null +++ b/tests/lean/Avl/Types.lean @@ -0,0 +1,70 @@ +-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS +-- [avl]: type definitions +import Base +open Primitives +set_option linter.dupNamespace false +set_option linter.hashCommand false +set_option linter.unusedVariables false + +namespace avl + +/- [avl::Ordering] + Source: 'src/avl.rs', lines 19:0-19:17 -/ +inductive Ordering := +| Less : Ordering +| Equal : Ordering +| Greater : Ordering + +/- Trait declaration: [avl::Ord] + Source: 'src/avl.rs', lines 25:0-25:13 -/ +structure Ord (Self : Type) where + cmp : Self → Self → Result Ordering + +/- [avl::Node] + Source: 'src/avl.rs', lines 29:0-29:14 -/ +inductive Node (T : Type) := +| mk : T → Option (Node T) → Option (Node T) → I8 → Node T + +@[reducible] +def Node.value {T : Type} (x : Node T) := + match x with | Node.mk x1 _ _ _ => x1 + +@[reducible] +def Node.left {T : Type} (x : Node T) := + match x with | Node.mk _ x1 _ _ => x1 + +@[reducible] +def Node.right {T : Type} (x : Node T) := + match x with | Node.mk _ _ x1 _ => x1 + +@[reducible] +def Node.balance_factor {T : Type} (x : Node T) := + match x with | Node.mk _ _ _ x1 => x1 + +@[simp] +theorem Node.value._simpLemma_ {T : Type} (value : T) (left : Option (Node T)) + (right : Option (Node T)) (balance_factor : I8) : + (Node.mk value left right balance_factor).value = value := by rfl + +@[simp] +theorem Node.left._simpLemma_ {T : Type} (value : T) (left : Option (Node T)) + (right : Option (Node T)) (balance_factor : I8) : + (Node.mk value left right balance_factor).left = left := by rfl + +@[simp] +theorem Node.right._simpLemma_ {T : Type} (value : T) (left : Option (Node T)) + (right : Option (Node T)) (balance_factor : I8) : + (Node.mk value left right balance_factor).right = right := by rfl + +@[simp] +theorem Node.balance_factor._simpLemma_ {T : Type} (value : T) (left : Option + (Node T)) (right : Option (Node T)) (balance_factor : I8) : + (Node.mk value left right balance_factor).balance_factor = balance_factor := + by rfl + +/- [avl::Tree] + Source: 'src/avl.rs', lines 36:0-36:18 -/ +structure Tree (T : Type) where + root : Option (Node T) + +end avl diff --git a/tests/lean/Betree/Types.lean b/tests/lean/Betree/Types.lean index 9e7c505b7..9ae80f529 100644 --- a/tests/lean/Betree/Types.lean +++ b/tests/lean/Betree/Types.lean @@ -49,22 +49,42 @@ inductive betree.Node := end -@[simp, reducible] +@[reducible] def betree.Internal.id (x : betree.Internal) := match x with | betree.Internal.mk x1 _ _ _ => x1 -@[simp, reducible] +@[reducible] def betree.Internal.pivot (x : betree.Internal) := match x with | betree.Internal.mk _ x1 _ _ => x1 -@[simp, reducible] +@[reducible] def betree.Internal.left (x : betree.Internal) := match x with | betree.Internal.mk _ _ x1 _ => x1 -@[simp, reducible] +@[reducible] def betree.Internal.right (x : betree.Internal) := match x with | betree.Internal.mk _ _ _ x1 => x1 +@[simp] +theorem betree.Internal.id._simpLemma_ (id : U64) (pivot : U64) (left : + betree.Node) (right : betree.Node) : + (betree.Internal.mk id pivot left right).id = id := by rfl + +@[simp] +theorem betree.Internal.pivot._simpLemma_ (id : U64) (pivot : U64) (left : + betree.Node) (right : betree.Node) : + (betree.Internal.mk id pivot left right).pivot = pivot := by rfl + +@[simp] +theorem betree.Internal.left._simpLemma_ (id : U64) (pivot : U64) (left : + betree.Node) (right : betree.Node) : + (betree.Internal.mk id pivot left right).left = left := by rfl + +@[simp] +theorem betree.Internal.right._simpLemma_ (id : U64) (pivot : U64) (left : + betree.Node) (right : betree.Node) : + (betree.Internal.mk id pivot left right).right = right := by rfl + /- [betree::betree::Params] Source: 'src/betree.rs', lines 187:0-187:13 -/ structure betree.Params where diff --git a/tests/lean/Hashmap/Properties.lean b/tests/lean/Hashmap/Properties.lean index 4329244b5..9913fab71 100644 --- a/tests/lean/Hashmap/Properties.lean +++ b/tests/lean/Hashmap/Properties.lean @@ -291,7 +291,7 @@ theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) ( if h: k = key then rw [insert_in_list] rw [insert_in_list_loop] - simp [h, and_assoc] + simp [h] split_conjs <;> simp_all [slot_s_inv_hash] else rw [insert_in_list] @@ -797,7 +797,7 @@ theorem move_elements_loop_spec progress as ⟨ ntable2, slots2, _, _, hLookup2Rev, hLookup21, hLookup22, hIndexNil ⟩ - simp [and_assoc] + simp have : ∀ (j : ℤ), OfNat.ofNat 0 ≤ j → j < slots2.val.len → slots2.val.index j = AList.Nil := by intro j h0 h1 apply hIndexNil j h0 h1 @@ -816,7 +816,7 @@ theorem move_elements_loop_spec simp_all [alloc.vec.Vec.len, or_assoc] apply hLookupPreserve else - simp [hi, and_assoc, *] + simp [hi, *] simp_all have hi : i = alloc.vec.Vec.len (AList α) slots := by scalar_tac have hEmpty : ∀ j, 0 ≤ j → j < slots.val.len → slots.val.index j = AList.Nil := by @@ -987,7 +987,7 @@ theorem get_mut_in_list_spec {α} (key : Usize) (slot : AList α) simp_all split . -- Non-recursive case - simp_all [and_assoc, slot_t_inv] + simp_all [slot_t_inv] . -- Recursive case -- TODO: progress doesn't instantiate l correctly rename _ → _ → _ => ih @@ -997,11 +997,11 @@ theorem get_mut_in_list_spec {α} (key : Usize) (slot : AList α) -- TODO: progress? notation to have some feedback have ⟨ v, back, hEq, _, hBack ⟩ := ih; clear ih simp [hEq]; clear hEq - simp [and_assoc, *] + simp [*] -- Proving the post-condition about back intro v progress as ⟨ slot', _, _, _, hForAll ⟩; clear hBack - simp [and_assoc, *] + simp [*] constructor . simp_all [slot_t_inv, slot_s_inv, slot_s_inv_hash] . simp_all @@ -1025,7 +1025,7 @@ theorem get_mut_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) (hLook have : slot.lookup key ≠ none := by simp_all [lookup] progress as ⟨ v, back .. ⟩ - simp [and_assoc, lookup, *] + simp [lookup, *] constructor . simp_all . -- Backward function @@ -1053,11 +1053,11 @@ theorem remove_from_list_spec {α} (key : Usize) (slot : AList α) {l i} (hInv : rw [remove_from_list, remove_from_list_loop] match hEq : slot with | .Nil => - simp [and_assoc] + simp | .Cons k v0 tl => simp if hKey : k = key then - simp [hKey, and_assoc] + simp [hKey] simp_all [slot_t_inv, slot_s_inv] apply slot_allP_not_key_lookup simp [*] @@ -1066,7 +1066,7 @@ theorem remove_from_list_spec {α} (key : Usize) (slot : AList α) {l i} (hInv : -- TODO: progress doesn't instantiate l properly have hInv' : slot_t_inv l i tl := by simp_all [slot_t_inv] have ⟨ v1, tl1, hRemove, _, _, hLookupTl1, _ ⟩ := remove_from_list_spec key tl hInv' - simp [and_assoc, *]; clear hRemove + simp [*]; clear hRemove constructor . intro key' hNotEq1 simp_all @@ -1105,7 +1105,7 @@ theorem remove_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : | none => simp [*] progress as ⟨ slot'' ⟩ - simp [and_assoc, lookup, *] + simp [lookup, *] simp_all [al_v, v] intro key' hNotEq -- We need to make a case disjunction @@ -1118,7 +1118,7 @@ theorem remove_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : simp_all [inv] progress as ⟨ newSize .. ⟩ progress as ⟨ slots1 .. ⟩ - simp_all [and_assoc, lookup, al_v, HashMap.v] + simp_all [lookup, al_v, HashMap.v] constructor . intro key' hNotEq cases h: (key.val % hm.slots.val.len) == (key'.val % hm.slots.val.len) <;> diff --git a/tests/lean/Issue194RecursiveStructProjector.lean b/tests/lean/Issue194RecursiveStructProjector.lean index 730d1aa6c..5f96ff71e 100644 --- a/tests/lean/Issue194RecursiveStructProjector.lean +++ b/tests/lean/Issue194RecursiveStructProjector.lean @@ -13,18 +13,33 @@ namespace issue_194_recursive_struct_projector inductive AVLNode (T : Type) := | mk : T → Option (AVLNode T) → Option (AVLNode T) → AVLNode T -@[simp, reducible] +@[reducible] def AVLNode.value {T : Type} (x : AVLNode T) := match x with | AVLNode.mk x1 _ _ => x1 -@[simp, reducible] +@[reducible] def AVLNode.left {T : Type} (x : AVLNode T) := match x with | AVLNode.mk _ x1 _ => x1 -@[simp, reducible] +@[reducible] def AVLNode.right {T : Type} (x : AVLNode T) := match x with | AVLNode.mk _ _ x1 => x1 +@[simp] +theorem AVLNode.value._simpLemma_ {T : Type} (value : T) (left : Option + (AVLNode T)) (right : Option (AVLNode T)) : + (AVLNode.mk value left right).value = value := by rfl + +@[simp] +theorem AVLNode.left._simpLemma_ {T : Type} (value : T) (left : Option (AVLNode + T)) (right : Option (AVLNode T)) : (AVLNode.mk value left right).left = left + := by rfl + +@[simp] +theorem AVLNode.right._simpLemma_ {T : Type} (value : T) (left : Option + (AVLNode T)) (right : Option (AVLNode T)) : + (AVLNode.mk value left right).right = right := by rfl + /- [issue_194_recursive_struct_projector::get_val]: Source: 'tests/src/issue-194-recursive-struct-projector.rs', lines 10:0-10:33 -/ def get_val (T : Type) (x : AVLNode T) : Result T := diff --git a/tests/lean/MiniTree.lean b/tests/lean/MiniTree.lean index 3566d2dbd..8ce962b9d 100644 --- a/tests/lean/MiniTree.lean +++ b/tests/lean/MiniTree.lean @@ -13,9 +13,13 @@ namespace mini_tree inductive Node := | mk : Option Node → Node -@[simp, reducible] +@[reducible] def Node.child (x : Node) := match x with | Node.mk x1 => x1 +@[simp] +theorem Node.child._simpLemma_ (child : Option Node) : + (Node.mk child).child = child := by rfl + /- [mini_tree::Tree] Source: 'tests/src/mini_tree.rs', lines 9:0-9:11 -/ structure Tree where diff --git a/tests/lean/lakefile.lean b/tests/lean/lakefile.lean index ba336a4af..984c1dafc 100644 --- a/tests/lean/lakefile.lean +++ b/tests/lean/lakefile.lean @@ -8,15 +8,16 @@ require base from "../../backends/lean" package «tests» {} -@[default_target] lean_lib Tutorial +@[default_target] lean_lib Arrays +@[default_target] lean_lib Avl @[default_target] lean_lib Betree @[default_target] lean_lib Constants +@[default_target] lean_lib Demo @[default_target] lean_lib External @[default_target] lean_lib Hashmap @[default_target] lean_lib Loops @[default_target] lean_lib NoNestedBorrows @[default_target] lean_lib Paper @[default_target] lean_lib PoloniusList -@[default_target] lean_lib Arrays @[default_target] lean_lib Traits -@[default_target] lean_lib Demo +@[default_target] lean_lib Tutorial diff --git a/tests/src/avl/Cargo.lock b/tests/src/avl/Cargo.lock new file mode 100644 index 000000000..37db529cf --- /dev/null +++ b/tests/src/avl/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "avl" +version = "0.1.0" diff --git a/tests/src/avl/Cargo.toml b/tests/src/avl/Cargo.toml new file mode 100644 index 000000000..c6e05d883 --- /dev/null +++ b/tests/src/avl/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "avl" +version = "0.1.0" +edition = "2021" + +[lib] +name = "avl" +path = "src/avl.rs" \ No newline at end of file diff --git a/tests/src/avl/aeneas-test-options b/tests/src/avl/aeneas-test-options new file mode 100644 index 000000000..a9153114b --- /dev/null +++ b/tests/src/avl/aeneas-test-options @@ -0,0 +1,2 @@ +[!lean] skip +[lean] aeneas-args=-split-files -no-gen-lib-entry \ No newline at end of file diff --git a/tests/src/avl/rust-toolchain b/tests/src/avl/rust-toolchain new file mode 100644 index 000000000..9460b1a82 --- /dev/null +++ b/tests/src/avl/rust-toolchain @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly-2023-06-02" +components = [ "rustc-dev", "llvm-tools-preview" ] diff --git a/tests/src/avl/src/avl.rs b/tests/src/avl/src/avl.rs new file mode 100644 index 000000000..3fe71a811 --- /dev/null +++ b/tests/src/avl/src/avl.rs @@ -0,0 +1,483 @@ +//! Adapted from https://en.wikipedia.org/wiki/AVL_tree +#![feature(register_tool)] +#![register_tool(aeneas)] +#![feature(box_patterns)] +#![feature(let_chains)] + +impl Ord for i32 { + fn cmp(&self, other: &Self) -> Ordering { + if *self < *other { + Ordering::Less + } else if *self == *other { + Ordering::Equal + } else { + Ordering::Greater + } + } +} + +pub enum Ordering { + Less, + Equal, + Greater, +} + +pub trait Ord { + fn cmp(&self, other: &Self) -> Ordering; +} + +struct Node { + value: T, + left: Option>>, + right: Option>>, + balance_factor: i8, +} + +pub struct Tree { + root: Option>>, +} + +impl Node { + fn rotate_left(root: &mut Box>, mut z : Box>) { + // We do (root is X): + // + // X + // / + // A Z + // / \ + // B C + // + // ~> + // + // Z + // / \ + // X C + // / \ + // A B + + let b = std::mem::replace(&mut z.left, None); + root.right = b; + // X + // / \ + // A B + // + // Z + // \ + // C + + let x = std::mem::replace(root, z); + root.left = Some(x); // root is now Z + // Z + // / \ + // X C + // / \ + // A B + + // Update the balance factors + if let Some(x) = &mut root.left { + if root.balance_factor == 0 { + x.balance_factor = 1; + root.balance_factor = -1; + } + else { + x.balance_factor = 0; + root.balance_factor = 0; + } + } + else { + panic!() + } + } + + fn rotate_right(root: &mut Box>, mut z : Box>) { + // We do (root is X): + // + // X + // \ + // Z C + // / \ + // A B + // + // ~> + // + // Z + // / \ + // A X + // / \ + // B C + + let b = std::mem::replace(&mut z.right, None); + root.left = b; + // X + // / \ + // B C + // + // Z + // / + // A + + let x = std::mem::replace(root, z); + root.right = Some(x); // root is now Z + + // Update the balance factors + if let Some(x) = &mut root.right { + if root.balance_factor == 0 { + x.balance_factor = -1; + root.balance_factor = 1; + } + else { + x.balance_factor = 0; + root.balance_factor = 0; + } + } + else { + panic!() + } + } + + fn rotate_left_right(root : &mut Box>, mut z : Box>) { + // We do (root is X): + // + // X + // \ + // Z 1 + // / \ + // 0 Y + // / \ + // A B + // + // ~> + // + // Y + // / \ + // Z X + // / \ / \ + // 0 A B 1 + + let mut y = std::mem::replace(&mut z.right, None).unwrap(); + let a = std::mem::replace(&mut y.left, None); + let b = std::mem::replace(&mut y.right, None); + z.right = a; + root.left = b; + + let x = std::mem::replace(root, y); + root.left = Some(z); + root.right = Some(x); + + // Update the balance factors + if let Some(x) = &mut root.right && let Some(z) = &mut root.left { + if root.balance_factor == 0 { + x.balance_factor = 0; + z.balance_factor = 0; + } + else if root.balance_factor < 0 { + x.balance_factor = 1; + z.balance_factor = 0; + } + else { + x.balance_factor = 0; + z.balance_factor = -1; + } + root.balance_factor = 0; + } + else { + panic!(); + } + } + + fn rotate_right_left(root : &mut Box>, mut z : Box>) { + // We do (root is X): + // + // X + // / + // 1 Z + // / \ + // Y 0 + // / \ + // B A + // + // ~> + // + // Y + // / \ + // X Z + // / \ / \ + // 1 B A 0 + + let mut y = std::mem::replace(&mut z.left, None).unwrap(); + let b = std::mem::replace(&mut y.left, None); + let a = std::mem::replace(&mut y.right, None); + z.left = a; + root.right = b; + + let x = std::mem::replace(root, y); + root.left = Some(x); + root.right = Some(z); + + // Update the balance factors + if let Some(x) = &mut root.left && let Some(z) = &mut root.right { + if root.balance_factor == 0 { + x.balance_factor = 0; + z.balance_factor = 0; + } + else if root.balance_factor > 0 { + x.balance_factor = -1; + z.balance_factor = 0; + } + else { + x.balance_factor = 0; + z.balance_factor = 1; + } + root.balance_factor = 0; + } + else { + panic!(); + } + } +} + +impl Node { + fn insert_in_left(node: &mut Box>, value: T) -> bool { + if Tree::insert_in_opt_node(&mut node.left, value) { + // We increased the height of the left node + node.balance_factor -= 1; + if node.balance_factor == -2 { + // The node is left-heavy: we need to rebalance + let left = std::mem::replace(&mut node.left, Option::None).unwrap(); + if left.balance_factor <= 0 { + // Note that the left balance factor is actually < 0 here + Node::rotate_right(node, left); + } + else { + Node::rotate_left_right(node, left); + } + // In order to udnerstand what happens here we need a drawing. + // In effect, as we rebalanced the tree, the total height did not + // increase compared to before the insertion operation. + false + } + else { + // If the balance factor changed from 1 to 0: the height did not + // change (we increased the height of the left subtree, which + // has now the same height as the right subtree). + // If it changed from 0 to -1: the subtrees originally had the same + // height and now the left subtree is strictly heigher than the + // right subtree: the total height is increased. + // This means that the height changed iff the new balance factor + // is different from 0 + node.balance_factor != 0 + } + } + else { + // The height did not change + false + } + } + + fn insert_in_right(node: &mut Box>, value: T) -> bool { + // Insert in the right sutbree, and check if we increased its height + if Tree::insert_in_opt_node(&mut node.right, value) { + // We increased the height of the right node + node.balance_factor += 1; + + if node.balance_factor == 2 { + // The node is right-heavy: we need to rebalance + let right = std::mem::replace(&mut node.right, Option::None).unwrap(); + if right.balance_factor >= 0 { + // Note that the right balance factor is actually > 0 here + Node::rotate_left(node, right); + } + else { + Node::rotate_right_left(node, right); + } + // In order to udnerstand what happens here we need a drawing. + // In effect, as we rebalanced the tree, the total height did not + // increase compared to before the insertion operation. + false + } + else { + // If the balance factor changed from -1 to 0: the height did not + // change (we increased the height of the right subtree, which + // has now the same height as the left subtree). + // If it changed from 0 to 1: the subtrees originally had the same + // height and now the right subtree is strictly heigher than the + // left subtree: the total height is increased. + // This means that the height changed iff the new balance factor + // is different from 0 + node.balance_factor != 0 + } + } + else { + // The height of the right subtree did not change so the total + // height did not change. + false + } + } + + // Return [true] if we increased the height of the tree + fn insert(node: &mut Box>, value: T) -> bool { + let ordering = value.cmp(&(*node).value); + + // Something important: we need to recompute the balance factor + // from the current balance factor and the change in height. Also, + // we need to compute whether the height of the current tree increased + // or not. + match ordering { + Ordering::Less => { + Node::insert_in_left(node, value) + }, + Ordering::Equal => false, + Ordering::Greater => { + Node::insert_in_right(node, value) + } + } + } +} + +impl Tree { + pub fn new() -> Self { + Self { root: None } + } + + pub fn find(&self, value: T) -> bool { + let mut current_tree = &self.root; + + while let Some(current_node) = current_tree { + match current_node.value.cmp(&value) { + Ordering::Less => current_tree = ¤t_node.right, + Ordering::Equal => return true, + Ordering::Greater => current_tree = ¤t_node.left, + } + } + + false + } + + fn insert_in_opt_node(node: &mut Option>>, value: T) -> bool { + match node { + Some(ref mut node) => { + Node::insert(node, value) + } + None => { + *node = Some(Box::new(Node { + value, + left: None, + right: None, + balance_factor: 0, + })); + true + } + } + } + + /// Insert a value and return [true] if the height of the tree increased. + pub fn insert(&mut self, value: T) -> bool { + Tree::insert_in_opt_node(&mut self.root, value) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl Ord for usize { + fn cmp(&self, other: &Self) -> Ordering { + if *self < *other { + Ordering::Less + } else if *self == *other { + Ordering::Equal + } else { + Ordering::Greater + } + } + } + + fn get_max(a: T, b: T) -> T { + match a.cmp(&b) { + Ordering::Less => b, + Ordering::Equal => b, + Ordering::Greater => a, + } + } + + + impl Node { + fn height(&self) -> usize { + 1 + get_max(self.right.as_deref().map_or(0, |n| n.height()), self.left.as_deref().map_or(0, |n| n.height())) + } + + fn balance_factor(&self) -> isize { + self.right.as_deref().map_or(0, |n| n.height()) as isize - self.left.as_deref().map_or(0, |n| n.height()) as isize + } + + fn check_inv(&self) { + let bf = self.balance_factor(); + assert!(-1 <= bf && bf <= 1); + assert!(bf == self.balance_factor as isize); + if let Some(n) = &self.left { + n.check_inv(); + } + if let Some(n) = &self.right { + n.check_inv(); + } + } + } + + impl Tree { + fn check_inv(&self) { + if let Some(n) = &self.root { + n.check_inv(); + } + } + } + + #[test] + fn test1() { + let mut t : Tree = Tree::new(); + + // Always inserting to the right + for i in 0..100 { + t.insert(i); + t.check_inv(); + } + } + + #[test] + fn test2() { + let mut t : Tree = Tree::new(); + + // Always inserting to the left + for i in 0..(-100) { + t.insert(i); + t.check_inv(); + } + } + + #[test] + fn test3() { + let mut t : Tree = Tree::new(); + + // Always inserting to the right + for i in 0..100 { + t.insert(i); + t.check_inv(); + } + + // Always inserting to the left + for i in 0..(-100) { + t.insert(i); + t.check_inv(); + } + } + + #[test] + fn test4() { + let mut t : Tree = Tree::new(); + + // Simulating randomness here + for i in 0..100 { + t.insert((i * i + 23 * i) % 17); + t.check_inv(); + } + } +} diff --git a/tests/test_runner/run_test.ml b/tests/test_runner/run_test.ml index bac50e84d..3d4e587d7 100644 --- a/tests/test_runner/run_test.ml +++ b/tests/test_runner/run_test.ml @@ -128,25 +128,34 @@ let run_charon (env : runner_env) (case : Input.t) = let args = List.append args case.charon_options in (* Run Charon on the rust file *) Command.run_command_expecting_success (Command.make args) - | Crate -> ( - match Sys.getenv_opt "IN_CI" with - | None -> - let args = - [ env.charon_path; "--dest"; Filename_unix.realpath env.llbc_dir ] - in - let args = List.append args case.charon_options in - (* Run Charon inside the crate *) - let old_pwd = Unix.getcwd () in - Unix.chdir case.path; - Command.run_command_expecting_success (Command.make args); - Unix.chdir old_pwd - | Some _ -> - (* Crates with dependencies must be generated separately in CI. We skip - here and trust that CI takes care to generate the necessary llbc - file. *) - print_endline - "Warn: IN_CI is set; we skip generating llbc files for whole crates" - ) + | Crate -> + (* Because some tests have dependencies which force us to implement custom + treatment in the flake.nix, when in CI, we regenerate files for the crates + only if the .llbc doesn't exist (if it exists, it means it was generated + via a custom derivation in the flake.nix) *) + let generate = + match Sys.getenv_opt "IN_CI" with + | None -> true + | Some _ -> + (* Check if the llbc file already exists *) + let llbc_name = env.llbc_dir ^ "/" ^ case.name ^ ".llbc" in + not (Sys.file_exists llbc_name) + in + if generate then ( + let args = + [ env.charon_path; "--dest"; Filename_unix.realpath env.llbc_dir ] + in + let args = List.append args case.charon_options in + (* Run Charon inside the crate *) + let old_pwd = Unix.getcwd () in + Unix.chdir case.path; + Command.run_command_expecting_success (Command.make args); + Unix.chdir old_pwd) + else + print_endline + ("Warn: crate test \"" ^ case.name + ^ "\": IN_CI is set and the llbc file already exists; we do not \ + regenerate the llbc file for the crate") let () = match Array.to_list Sys.argv with