Skip to content

Commit

Permalink
Lift entire var in MutRef if later mutably borrowed.
Browse files Browse the repository at this point in the history
  • Loading branch information
yannbolliger committed Jun 1, 2021
1 parent 0b6d62d commit 5b0e976
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 22 deletions.
15 changes: 15 additions & 0 deletions stainless_data/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,23 @@ impl Variable<'_> {
pub fn is_mutable(&self) -> bool {
self.flags.iter().any(|f| matches!(f, Flag::IsVar(_)))
}

pub fn is_mut_ref(&self) -> bool {
self
.flags
.iter()
.any(|f| matches!(f, Flag::Annotation(Annotation{name,..}) if name == "wrapped"))
}
}

impl ValDef<'_> {
pub fn is_mutable(&self) -> bool {
self.v.is_mutable()
}

pub fn is_mut_ref(&self) -> bool {
self.v.is_mut_ref()
}
}

impl TypeParameter<'_> {
Expand Down Expand Up @@ -251,4 +262,8 @@ impl Factory {
pub fn evidence_flag(&self) -> Flag<'_> {
self.Annotation("evidence".into(), vec![]).into()
}

pub fn wrapped_flag(&self) -> Flag<'_> {
self.Annotation("wrapped".into(), vec![]).into()
}
}
44 changes: 43 additions & 1 deletion stainless_extraction/src/bindings.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use super::flags::Flags;
use super::*;

use rustc_hir::{self as hir, HirId, Node, Pat, PatKind};
use rustc_hir::{self as hir, def, BorrowKind, HirId, Node, Pat, PatKind};
use rustc_middle::ty;
use rustc_span::symbol::Symbol;

use stainless_data::ast as st;
use std::iter;

impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
/// Build a DefContext that includes all variable bindings
Expand Down Expand Up @@ -188,6 +189,24 @@ impl<'l> DefContext<'l> {
pub(super) fn get_var(&self, hir_id: HirId) -> Option<&'l st::Variable<'l>> {
self.vars.get(&hir_id).copied()
}

pub(super) fn make_mut_ref(&mut self, hir_id: HirId, xtor: &mut BaseExtractor<'l, '_>) {
self.vars.entry(hir_id).and_modify(|v| {
if !v.is_mut_ref() {
let f = xtor.factory();
let tpe = xtor.synth().mut_ref_type(v.tpe);
*v = f.Variable(
v.id,
tpe,
v.flags
.iter()
.copied()
.chain(iter::once(f.wrapped_flag()))
.collect(),
)
}
});
}
}

/// BindingsCollector populates a DefContext
Expand Down Expand Up @@ -220,6 +239,29 @@ impl<'tcx> Visitor<'tcx> for BindingsCollector<'_, '_, '_, 'tcx> {
unreachable!();
}

fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
// If a local variable (path) is mutably borrowed
if let hir::ExprKind::AddrOf(
BorrowKind::Ref,
Mutability::Mut,
hir::Expr {
kind:
hir::ExprKind::Path(hir::QPath::Resolved(
_,
hir::Path {
res: def::Res::Local(id),
..
},
)),
..
},
) = expr.kind
{
self.bxtor.dcx.make_mut_ref(*id, &mut self.bxtor.base)
}
intravisit::walk_expr(self, expr)
}

fn visit_pat(&mut self, pattern: &'tcx hir::Pat<'tcx>) {
match pattern.kind {
hir::PatKind::Binding(_, hir_id, ref _ident, ref optional_subpattern) => {
Expand Down
13 changes: 10 additions & 3 deletions stainless_extraction/src/expr/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,24 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
pattern.span,
),
Ok(vd) => {
let init_expr = self.extract_aliasable_expr(init);
let init_expr = if vd.is_mut_ref() {
let tpe = self.base.extract_ty(init.ty, &self.txtcx, init.span);
self.synth().mut_ref(tpe, init_expr)
} else {
init_expr
};

// recurse the extract all the following statements
let exprs = acc_exprs.clone();
acc_exprs.clear();
let body_expr = self.extract_block_(stmts, acc_exprs, acc_specs, final_expr);

// wrap that body expression into the Let
let init = self.extract_aliasable_expr(init);
let last_expr = if vd.is_mutable() {
f.LetVar(vd, init, body_expr).into()
f.LetVar(vd, init_expr, body_expr).into()
} else {
f.Let(vd, init, body_expr).into()
f.Let(vd, init_expr, body_expr).into()
};
finish(exprs, last_expr)
}
Expand Down
2 changes: 1 addition & 1 deletion stainless_extraction/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {

ExprKind::Tuple { fields } => self.extract_tuple(fields, expr.span),
ExprKind::Field { lhs, name } => self.extract_field(lhs, *name),
ExprKind::VarRef { id } => self.fetch_var(*id).into(),
ExprKind::VarRef { id } => self.extract_var_ref(*id),

ExprKind::Call { ty, ref args, .. } => match ty.kind() {
TyKind::FnDef(def_id, substs_ref) => {
Expand Down
40 changes: 24 additions & 16 deletions stainless_extraction/src/expr/refs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ use super::*;

impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
pub(super) fn extract_deref(&mut self, arg: &'a Expr<'a, 'tcx>) -> st::Expr<'l> {
match arg.ty.ref_mutability() {
Some(Mutability::Mut) => {
let arg = self.extract_expr(arg);
self.synth().mut_ref_value(arg)
}
_ => self.extract_expr(arg),
let expr = self.extract_expr(arg);
if is_mut_ref(arg.ty) {
self.synth().mut_ref_value(expr)
} else {
expr
}
}

Expand All @@ -25,18 +24,27 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
// aliasing is safe.
BorrowKind::Shared => self.extract_aliasable_expr(arg),

BorrowKind::Mut { .. } => match arg.kind {
// Re-borrows a dereferenced mutable borrow, we take the original ref
ExprKind::Deref { arg } if matches!(arg.ty.ref_mutability(), Some(Mutability::Mut)) => {
self.extract_expr(arg)
}
_ => {
let tpe = self.base.extract_ty(arg.ty, &self.txtcx, arg.span);
let a = self.extract_expr(arg);
self.synth().mut_ref(tpe, a)
BorrowKind::Mut { .. } => {
let arg = self.strip_scopes(arg);
match arg.kind {
// Re-borrows a dereferenced mutable borrow, we take the original ref
ExprKind::Deref { arg: inner } if is_mut_ref(inner.ty) => self.extract_expr(inner),

ExprKind::VarRef { id } => self.fetch_var(id).into(),
_ => self.extract_expr(arg),
}
},
}

_ => self.unsupported_expr(arg.span, format!("borrow kind {:?}", borrow_kind)),
}
}

pub(super) fn extract_var_ref(&mut self, id: HirId) -> st::Expr<'l> {
let var = self.fetch_var(id);
if var.is_mut_ref() {
self.synth().mut_ref_value(var.into())
} else {
var.into()
}
}
}
2 changes: 1 addition & 1 deletion stainless_extraction/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub fn uint_bit_width(int_ty: &UintTy, tcx: TyCtxt<'_>) -> u64 {
int_ty.bit_width().unwrap_or_else(|| pointer_bit_width(tcx))
}

pub fn is_mut_ref<'tcx>(ty: Ty<'tcx>) -> bool {
pub fn is_mut_ref(ty: Ty) -> bool {
matches!(ty.ref_mutability(), Some(Mutability::Mut))
}

Expand Down

0 comments on commit 5b0e976

Please sign in to comment.