From 5b0e976b7f35a12b57e0492bbbbc4b3d7604bbf9 Mon Sep 17 00:00:00 2001 From: Yann Bolliger Date: Tue, 1 Jun 2021 01:28:11 +0200 Subject: [PATCH] Lift entire var in MutRef if later mutably borrowed. --- stainless_data/src/ast.rs | 15 +++++++++ stainless_extraction/src/bindings.rs | 44 +++++++++++++++++++++++++- stainless_extraction/src/expr/block.rs | 13 ++++++-- stainless_extraction/src/expr/mod.rs | 2 +- stainless_extraction/src/expr/refs.rs | 40 +++++++++++++---------- stainless_extraction/src/ty.rs | 2 +- 6 files changed, 94 insertions(+), 22 deletions(-) diff --git a/stainless_data/src/ast.rs b/stainless_data/src/ast.rs index 5e336e6a..20b6b662 100644 --- a/stainless_data/src/ast.rs +++ b/stainless_data/src/ast.rs @@ -154,12 +154,23 @@ impl Variable<'_> { pub fn is_mutable(&self) -> bool { self.flags.iter().any(|f| matches!(f, Flag::IsVar(_))) } + + pub fn is_mut_ref(&self) -> bool { + self + .flags + .iter() + .any(|f| matches!(f, Flag::Annotation(Annotation{name,..}) if name == "wrapped")) + } } impl ValDef<'_> { pub fn is_mutable(&self) -> bool { self.v.is_mutable() } + + pub fn is_mut_ref(&self) -> bool { + self.v.is_mut_ref() + } } impl TypeParameter<'_> { @@ -251,4 +262,8 @@ impl Factory { pub fn evidence_flag(&self) -> Flag<'_> { self.Annotation("evidence".into(), vec![]).into() } + + pub fn wrapped_flag(&self) -> Flag<'_> { + self.Annotation("wrapped".into(), vec![]).into() + } } diff --git a/stainless_extraction/src/bindings.rs b/stainless_extraction/src/bindings.rs index edeecfcb..900b676a 100644 --- a/stainless_extraction/src/bindings.rs +++ b/stainless_extraction/src/bindings.rs @@ -1,11 +1,12 @@ use super::flags::Flags; use super::*; -use rustc_hir::{self as hir, HirId, Node, Pat, PatKind}; +use rustc_hir::{self as hir, def, BorrowKind, HirId, Node, Pat, PatKind}; use rustc_middle::ty; use rustc_span::symbol::Symbol; use stainless_data::ast as st; +use std::iter; impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { /// Build a DefContext that includes all variable bindings @@ -188,6 +189,24 @@ impl<'l> DefContext<'l> { pub(super) fn get_var(&self, hir_id: HirId) -> Option<&'l st::Variable<'l>> { self.vars.get(&hir_id).copied() } + + pub(super) fn make_mut_ref(&mut self, hir_id: HirId, xtor: &mut BaseExtractor<'l, '_>) { + self.vars.entry(hir_id).and_modify(|v| { + if !v.is_mut_ref() { + let f = xtor.factory(); + let tpe = xtor.synth().mut_ref_type(v.tpe); + *v = f.Variable( + v.id, + tpe, + v.flags + .iter() + .copied() + .chain(iter::once(f.wrapped_flag())) + .collect(), + ) + } + }); + } } /// BindingsCollector populates a DefContext @@ -220,6 +239,29 @@ impl<'tcx> Visitor<'tcx> for BindingsCollector<'_, '_, '_, 'tcx> { unreachable!(); } + fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) { + // If a local variable (path) is mutably borrowed + if let hir::ExprKind::AddrOf( + BorrowKind::Ref, + Mutability::Mut, + hir::Expr { + kind: + hir::ExprKind::Path(hir::QPath::Resolved( + _, + hir::Path { + res: def::Res::Local(id), + .. + }, + )), + .. + }, + ) = expr.kind + { + self.bxtor.dcx.make_mut_ref(*id, &mut self.bxtor.base) + } + intravisit::walk_expr(self, expr) + } + fn visit_pat(&mut self, pattern: &'tcx hir::Pat<'tcx>) { match pattern.kind { hir::PatKind::Binding(_, hir_id, ref _ident, ref optional_subpattern) => { diff --git a/stainless_extraction/src/expr/block.rs b/stainless_extraction/src/expr/block.rs index af200ef6..90c15914 100644 --- a/stainless_extraction/src/expr/block.rs +++ b/stainless_extraction/src/expr/block.rs @@ -120,17 +120,24 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { pattern.span, ), Ok(vd) => { + let init_expr = self.extract_aliasable_expr(init); + let init_expr = if vd.is_mut_ref() { + let tpe = self.base.extract_ty(init.ty, &self.txtcx, init.span); + self.synth().mut_ref(tpe, init_expr) + } else { + init_expr + }; + // recurse the extract all the following statements let exprs = acc_exprs.clone(); acc_exprs.clear(); let body_expr = self.extract_block_(stmts, acc_exprs, acc_specs, final_expr); // wrap that body expression into the Let - let init = self.extract_aliasable_expr(init); let last_expr = if vd.is_mutable() { - f.LetVar(vd, init, body_expr).into() + f.LetVar(vd, init_expr, body_expr).into() } else { - f.Let(vd, init, body_expr).into() + f.Let(vd, init_expr, body_expr).into() }; finish(exprs, last_expr) } diff --git a/stainless_extraction/src/expr/mod.rs b/stainless_extraction/src/expr/mod.rs index 059d4e47..4bbf04f1 100644 --- a/stainless_extraction/src/expr/mod.rs +++ b/stainless_extraction/src/expr/mod.rs @@ -36,7 +36,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { ExprKind::Tuple { fields } => self.extract_tuple(fields, expr.span), ExprKind::Field { lhs, name } => self.extract_field(lhs, *name), - ExprKind::VarRef { id } => self.fetch_var(*id).into(), + ExprKind::VarRef { id } => self.extract_var_ref(*id), ExprKind::Call { ty, ref args, .. } => match ty.kind() { TyKind::FnDef(def_id, substs_ref) => { diff --git a/stainless_extraction/src/expr/refs.rs b/stainless_extraction/src/expr/refs.rs index b91328a4..f0c95468 100644 --- a/stainless_extraction/src/expr/refs.rs +++ b/stainless_extraction/src/expr/refs.rs @@ -4,12 +4,11 @@ use super::*; impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { pub(super) fn extract_deref(&mut self, arg: &'a Expr<'a, 'tcx>) -> st::Expr<'l> { - match arg.ty.ref_mutability() { - Some(Mutability::Mut) => { - let arg = self.extract_expr(arg); - self.synth().mut_ref_value(arg) - } - _ => self.extract_expr(arg), + let expr = self.extract_expr(arg); + if is_mut_ref(arg.ty) { + self.synth().mut_ref_value(expr) + } else { + expr } } @@ -25,18 +24,27 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { // aliasing is safe. BorrowKind::Shared => self.extract_aliasable_expr(arg), - BorrowKind::Mut { .. } => match arg.kind { - // Re-borrows a dereferenced mutable borrow, we take the original ref - ExprKind::Deref { arg } if matches!(arg.ty.ref_mutability(), Some(Mutability::Mut)) => { - self.extract_expr(arg) - } - _ => { - let tpe = self.base.extract_ty(arg.ty, &self.txtcx, arg.span); - let a = self.extract_expr(arg); - self.synth().mut_ref(tpe, a) + BorrowKind::Mut { .. } => { + let arg = self.strip_scopes(arg); + match arg.kind { + // Re-borrows a dereferenced mutable borrow, we take the original ref + ExprKind::Deref { arg: inner } if is_mut_ref(inner.ty) => self.extract_expr(inner), + + ExprKind::VarRef { id } => self.fetch_var(id).into(), + _ => self.extract_expr(arg), } - }, + } + _ => self.unsupported_expr(arg.span, format!("borrow kind {:?}", borrow_kind)), } } + + pub(super) fn extract_var_ref(&mut self, id: HirId) -> st::Expr<'l> { + let var = self.fetch_var(id); + if var.is_mut_ref() { + self.synth().mut_ref_value(var.into()) + } else { + var.into() + } + } } diff --git a/stainless_extraction/src/ty.rs b/stainless_extraction/src/ty.rs index 521a5793..45d73f2d 100644 --- a/stainless_extraction/src/ty.rs +++ b/stainless_extraction/src/ty.rs @@ -67,7 +67,7 @@ pub fn uint_bit_width(int_ty: &UintTy, tcx: TyCtxt<'_>) -> u64 { int_ty.bit_width().unwrap_or_else(|| pointer_bit_width(tcx)) } -pub fn is_mut_ref<'tcx>(ty: Ty<'tcx>) -> bool { +pub fn is_mut_ref(ty: Ty) -> bool { matches!(ty.ref_mutability(), Some(Mutability::Mut)) }