From 9272b0410664a79146f1d57feed8c4dad155d2e6 Mon Sep 17 00:00:00 2001 From: Yann Bolliger Date: Fri, 4 Jun 2021 15:00:23 +0200 Subject: [PATCH] Support pattern match on mutably borrowed fields. --- stainless_extraction/src/expr/pattern.rs | 64 ++++++++++++++++++------ 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/stainless_extraction/src/expr/pattern.rs b/stainless_extraction/src/expr/pattern.rs index f50b91c1..d05dd335 100644 --- a/stainless_extraction/src/expr/pattern.rs +++ b/stainless_extraction/src/expr/pattern.rs @@ -34,19 +34,14 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { PatKind::Binding { mutability, - mode, var: hir_id, .. - } if matches!( - mode, - BindingMode::ByValue | BindingMode::ByRef(BorrowKind::Shared) - ) => - { + } => { let var = self.fetch_var(*hir_id); if *mutability == Mutability::Not || var.is_mutable() { Ok(self.factory().ValDef(var)) } else { - Err("Binding mode not allowed") + Err("Mutability not allowed") } } @@ -151,7 +146,11 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { TyKind::Tuple(substs) => f .TuplePattern( binder, - self.extract_subpatterns(subpatterns.to_vec(), substs.len()), + self + .extract_subpatterns(subpatterns.to_vec(), substs.len()) + .into_iter() + .map(|p| p.0) + .collect(), ) .into(), @@ -167,7 +166,23 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { }, // TODO: Confirm that rustc introduces this pattern only for primitive derefs - box PatKind::Deref { ref subpattern } => self.extract_pattern(subpattern, binder), + box PatKind::Deref { ref subpattern } => { + let sub = self.extract_pattern(subpattern, binder); + + // If we pattern match on mutable vars, we want the MutCell not it's value + if ty::is_mut_ref(pattern.ty) { + if let st::Type::ADTType(st::ADTType { tps, .. }) = + self.base.extract_ty(pattern.ty, &self.txtcx, pattern.span) + { + f.ADTPattern(None, self.synth().mut_cell_id(), tps.clone(), vec![sub]) + .into() + } else { + sub + } + } else { + sub + } + } _ => self.unsupported_pattern(pattern.span, "Unsupported kind of pattern"), } @@ -197,9 +212,13 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { .base .adt_field_types(adt_def, variant_index, &self.txtcx, substs), ) - .map(|(p, t)| { - f.ADTPattern(None, self.synth().mut_cell_id(), vec![t], vec![p]) - .into() + .map(|((st_pat, thir_pat), t)| { + if is_mut_ref(thir_pat) { + st_pat + } else { + f.ADTPattern(None, self.synth().mut_cell_id(), vec![t], vec![st_pat]) + .into() + } }) .collect(); @@ -210,7 +229,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { &mut self, mut field_pats: Vec>, num_fields: usize, - ) -> Vec> { + ) -> Vec<(st::Pattern<'l>, Option>)> { let f = self.factory(); field_pats.sort_by_key(|field| field.field.index()); field_pats.reverse(); @@ -219,15 +238,28 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { let next = if let Some(FieldPat { field, .. }) = field_pats.last() { if field.index() == i { let FieldPat { pattern, .. } = field_pats.pop().unwrap(); - self.extract_pattern(&pattern, None) + (self.extract_pattern(&pattern, None), Some(pattern)) } else { - f.WildcardPattern(None).into() + (f.WildcardPattern(None).into(), None) } } else { - f.WildcardPattern(None).into() + (f.WildcardPattern(None).into(), None) }; subpatterns.push(next); } subpatterns } } + +fn is_mut_ref<'tcx>(pat: Option>) -> bool { + matches!( + pat, + Some(Pat { + kind: box PatKind::Binding { + mode: BindingMode::ByRef(BorrowKind::Mut { .. }), + .. + }, + .. + }) + ) +}