From b6ebd474d72f6f092e29bc83f1a524cb9674620a Mon Sep 17 00:00:00 2001 From: David Binder Date: Wed, 17 Apr 2024 13:10:56 +0200 Subject: [PATCH] Unify ShiftInRange and Shift traits (#183) * Unify ShiftInRange and Shift traits * Add comments and tests for Shift trait * Add more comments --- lang/elaborator/src/normalizer/env.rs | 2 +- lang/elaborator/src/normalizer/val.rs | 10 +-- lang/elaborator/src/unifier/unify.rs | 4 +- lang/lifting/src/fv.rs | 6 +- lang/printer/src/generic.rs | 24 +++---- lang/syntax/src/common/de_bruijn.rs | 95 ++++++++++++++++++++------- lang/syntax/src/common/subst.rs | 8 +-- lang/syntax/src/ctx/values.rs | 2 +- lang/syntax/src/trees/generic/fold.rs | 8 +-- lang/syntax/src/trees/ust/shift.rs | 12 ++-- lang/syntax/src/trees/ust/subst.rs | 2 +- 11 files changed, 111 insertions(+), 62 deletions(-) diff --git a/lang/elaborator/src/normalizer/env.rs b/lang/elaborator/src/normalizer/env.rs index a704e2436..ea493a5c8 100644 --- a/lang/elaborator/src/normalizer/env.rs +++ b/lang/elaborator/src/normalizer/env.rs @@ -95,7 +95,7 @@ impl From>>> for Env { } } -impl ShiftInRange for Env { +impl Shift for Env { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { self.map(|val| val.shift_in_range(range.clone(), by)) } diff --git a/lang/elaborator/src/normalizer/val.rs b/lang/elaborator/src/normalizer/val.rs index 8b2b40146..77c97eec1 100644 --- a/lang/elaborator/src/normalizer/val.rs +++ b/lang/elaborator/src/normalizer/val.rs @@ -115,7 +115,7 @@ pub struct Closure { pub body: Rc, } -impl ShiftInRange for Val { +impl Shift for Val { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { match self { Val::TypCtor { info, name, args } => Val::TypCtor { @@ -138,7 +138,7 @@ impl ShiftInRange for Val { } } -impl ShiftInRange for Neu { +impl Shift for Neu { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { match self { Neu::Var { info, name, idx } => { @@ -161,14 +161,14 @@ impl ShiftInRange for Neu { } } -impl ShiftInRange for Match { +impl Shift for Match { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Match { info, cases, omit_absurd } = self; Match { info: *info, cases: cases.shift_in_range(range, by), omit_absurd: *omit_absurd } } } -impl ShiftInRange for Case { +impl Shift for Case { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Case { info, name, args, body } = self; @@ -181,7 +181,7 @@ impl ShiftInRange for Case { } } -impl ShiftInRange for Closure { +impl Shift for Closure { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Closure { env, n_args, body } = self; diff --git a/lang/elaborator/src/unifier/unify.rs b/lang/elaborator/src/unifier/unify.rs index b47bcd065..f12e4b62b 100644 --- a/lang/elaborator/src/unifier/unify.rs +++ b/lang/elaborator/src/unifier/unify.rs @@ -26,7 +26,7 @@ impl From<(Rc, Rc)> for Eqn { } } -impl ShiftInRange for Eqn { +impl Shift for Eqn { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Eqn { lhs, rhs } = self; Eqn { lhs: lhs.shift_in_range(range.clone(), by), rhs: rhs.shift_in_range(range, by) } @@ -44,7 +44,7 @@ impl Substitutable> for Unificator { } } -impl ShiftInRange for Unificator { +impl Shift for Unificator { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { Self { map: self diff --git a/lang/lifting/src/fv.rs b/lang/lifting/src/fv.rs index 9e7f221a8..23cdc34e1 100644 --- a/lang/lifting/src/fv.rs +++ b/lang/lifting/src/fv.rs @@ -301,21 +301,21 @@ impl FVSubst { } } -impl ShiftInRange for FVSubst { +impl Shift for FVSubst { fn shift_in_range(&self, _range: R, _by: (isize, isize)) -> Self { // Since FVSubst works with levels, it is shift-invariant self.clone() } } -impl<'a> ShiftInRange for FVBodySubst<'a> { +impl<'a> Shift for FVBodySubst<'a> { fn shift_in_range(&self, _range: R, _by: (isize, isize)) -> FVBodySubst<'a> { // Since FVSubst works with levels, it is shift-invariant FVBodySubst { inner: self.inner } } } -impl<'a> ShiftInRange for FVParamSubst<'a> { +impl<'a> Shift for FVParamSubst<'a> { fn shift_in_range(&self, _range: R, _by: (isize, isize)) -> FVParamSubst<'a> { // Since FVSubst works with levels, it is shift-invariant FVParamSubst { inner: self.inner } diff --git a/lang/printer/src/generic.rs b/lang/printer/src/generic.rs index facc66f08..67e322907 100644 --- a/lang/printer/src/generic.rs +++ b/lang/printer/src/generic.rs @@ -17,7 +17,7 @@ fn is_visible(attr: &Attribute) -> bool { impl<'a, P: Phase> Print<'a> for Prg

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Prg { decls } = self; @@ -27,7 +27,7 @@ where impl<'a, P: Phase> Print<'a> for Decls

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let items = @@ -46,7 +46,7 @@ where impl<'a, P: Phase> PrintInCtx<'a> for Decl

where - Exp

: ShiftInRange, + Exp

: Shift, { type Ctx = Decls

; @@ -70,7 +70,7 @@ where impl<'a, P: Phase> PrintInCtx<'a> for Item<'a, P> where - Exp

: ShiftInRange, + Exp

: Shift, { type Ctx = Decls

; @@ -92,7 +92,7 @@ where impl<'a, P: Phase> PrintInCtx<'a> for Data

where - Exp

: ShiftInRange, + Exp

: Shift, { type Ctx = Decls

; @@ -140,7 +140,7 @@ where impl<'a, P: Phase> PrintInCtx<'a> for Codata

where - Exp

: ShiftInRange, + Exp

: Shift, { type Ctx = Decls

; @@ -188,7 +188,7 @@ where impl<'a, P: Phase> Print<'a> for Def

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Def { info: _, doc, name, attr, params, self_param, ret_typ, body } = self; @@ -216,7 +216,7 @@ where impl<'a, P: Phase> Print<'a> for Codef

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Codef { info: _, doc, name, attr, params, typ, body } = self; @@ -242,7 +242,7 @@ where impl<'a, P: Phase> Print<'a> for Let

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Let { info: _, doc, name, attr, params, typ, body } = self; @@ -268,7 +268,7 @@ where impl<'a, P: Phase> Print<'a> for Ctor

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Ctor { info: _, doc, name, params, typ } = self; @@ -287,7 +287,7 @@ where impl<'a, P: Phase> Print<'a> for Dtor

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Dtor { info: _, doc, name, params, self_param, ret_typ } = self; @@ -369,7 +369,7 @@ impl<'a, P: Phase> Print<'a> for Case

{ impl<'a, P: Phase> Print<'a> for Telescope

where - Exp

: ShiftInRange, + Exp

: Shift, { fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> { let Telescope { params } = self; diff --git a/lang/syntax/src/common/de_bruijn.rs b/lang/syntax/src/common/de_bruijn.rs index 191f978fb..ce244fdd9 100644 --- a/lang/syntax/src/common/de_bruijn.rs +++ b/lang/syntax/src/common/de_bruijn.rs @@ -77,24 +77,42 @@ pub trait Leveled { } /// De-Bruijn shifting -pub trait Shift { - /// Shift a term in the first component of the two-dimensional De-Bruijn index - fn shift(&self, by: (isize, isize)) -> Self; -} - -pub trait ShiftRange: RangeBounds + Clone {} - -impl + Clone> ShiftRange for T {} +/// +/// When we manipulate terms using de Bruijn notation we often +/// have to change the de Bruijn indices of the variables inside +/// a term. This is what the "shift" and "shift_in_range" functions +/// from this trait are for. +/// +/// Simplified Example: Consider the lambda calculus with de Bruijn +/// indices whose syntax is "e := n | λ_. e | e e". The shift_in_range +/// operation would be defined as follows: +/// - n.shift_in_range(range, by) = if (n ∈ range) then { n + by } else { n } +/// - (λ_. e).shift_in_range(range, by) = λ_.(e.shift_in_range(range.left += 1, by)) +/// - (e1 e2).shift_in_range(range, by) = (e1.shift_in_range(range, by)) (e2.shift_in_range(range, by)) +/// So whenever we traverse a binding occurrence we have to bump the left +/// side of the range by one. +/// +/// Note: We use two-level de Bruijn indices. The cutoff-range only applies to +/// the first element of a two-level de Bruijn index. +/// +/// Ref: https://www.cs.cornell.edu/courses/cs4110/2018fa/lectures/lecture15.pdf +pub trait Shift: Sized { + /// Shift all open variables in `self` by the the value indicated with the + /// `by` argument. + fn shift(&self, by: (isize, isize)) -> Self { + self.shift_in_range(0.., by) + } -pub trait ShiftInRange { + /// Shift every de Bruijn index contained in `self` by the value indicated + /// with the `by` argument. De Bruijn indices whose first component does not + /// lie within the indicated `range` are not affected by the shift. + /// + /// In order to implement `shift_in_range` correctly you have to increase the + /// left endpoint of `range` by 1 whenever you go recursively under a binder. fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self; } -impl ShiftInRange for () { - fn shift_in_range(&self, _range: R, _by: (isize, isize)) -> Self {} -} - -impl ShiftInRange for Idx { +impl Shift for Idx { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { if range.contains(&self.fst) { Self { @@ -107,30 +125,61 @@ impl ShiftInRange for Idx { } } -impl ShiftInRange for Rc { +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn shift_fst() { + let result = Idx { fst: 0, snd: 0 }.shift((1, 0)); + assert_eq!(result, Idx { fst: 1, snd: 0 }); + } + + #[test] + fn shift_snd() { + let result = Idx { fst: 0, snd: 0 }.shift((0, 1)); + assert_eq!(result, Idx { fst: 0, snd: 1 }); + } + + #[test] + fn shift_in_range_fst() { + let result = Idx { fst: 0, snd: 0 }.shift_in_range(1.., (1, 0)); + assert_eq!(result, Idx { fst: 0, snd: 0 }); + } + + #[test] + fn shift_in_range_snd() { + let result = Idx { fst: 0, snd: 0 }.shift_in_range(1.., (0, 1)); + assert_eq!(result, Idx { fst: 0, snd: 0 }); + } +} + +pub trait ShiftRange: RangeBounds + Clone {} + +impl + Clone> ShiftRange for T {} + +impl Shift for () { + fn shift_in_range(&self, _range: R, _by: (isize, isize)) -> Self {} +} + +impl Shift for Rc { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { Rc::new((**self).shift_in_range(range, by)) } } -impl ShiftInRange for Option { +impl Shift for Option { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { self.as_ref().map(|inner| inner.shift_in_range(range, by)) } } -impl ShiftInRange for Vec { +impl Shift for Vec { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { self.iter().map(|x| x.shift_in_range(range.clone(), by)).collect() } } -impl Shift for T { - fn shift(&self, by: (isize, isize)) -> Self { - self.shift_in_range(0.., by) - } -} - pub trait ShiftRangeExt { type Target; diff --git a/lang/syntax/src/common/subst.rs b/lang/syntax/src/common/subst.rs index d2f37d8c2..ba3354884 100644 --- a/lang/syntax/src/common/subst.rs +++ b/lang/syntax/src/common/subst.rs @@ -3,11 +3,11 @@ use crate::ctx::*; pub struct Assign(pub K, pub V); -pub trait Substitution: ShiftInRange { +pub trait Substitution: Shift { fn get_subst(&self, ctx: &LevelCtx, lvl: Lvl) -> Option; } -impl Substitution for Vec> { +impl Substitution for Vec> { fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option { if lvl.fst >= self.len() { return None; @@ -16,13 +16,13 @@ impl Substitution for Vec> { } } -impl ShiftInRange for Assign { +impl Shift for Assign { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { Assign(self.0.clone(), self.1.shift_in_range(range, by)) } } -impl Substitution for Assign { +impl Substitution for Assign { fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option { if self.0 == lvl { Some(self.1.clone()) diff --git a/lang/syntax/src/ctx/values.rs b/lang/syntax/src/ctx/values.rs index e1dc734e6..7da5987b9 100644 --- a/lang/syntax/src/ctx/values.rs +++ b/lang/syntax/src/ctx/values.rs @@ -132,7 +132,7 @@ pub struct Binder { pub typ: Rc, } -impl ShiftInRange for Binder { +impl Shift for Binder { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { Self { name: self.name.clone(), typ: self.typ.shift_in_range(range, by) } } diff --git a/lang/syntax/src/trees/generic/fold.rs b/lang/syntax/src/trees/generic/fold.rs index 5d4a0fb2b..4899e34e3 100644 --- a/lang/syntax/src/trees/generic/fold.rs +++ b/lang/syntax/src/trees/generic/fold.rs @@ -223,7 +223,7 @@ impl> Fold for Vec { impl Fold for Prg

where - P::InfTyp: ShiftInRange, + P::InfTyp: Shift, { type Out = O::Prg; @@ -239,7 +239,7 @@ where impl Fold for Decls

where - P::InfTyp: ShiftInRange, + P::InfTyp: Shift, { type Out = O::Decls; @@ -255,7 +255,7 @@ where impl Fold for Decl

where - P::InfTyp: ShiftInRange, + P::InfTyp: Shift, { type Out = O::Decl; @@ -379,7 +379,7 @@ impl Fold for Dtor

{ impl Fold for Def

where - P::InfTyp: ShiftInRange, + P::InfTyp: Shift, { type Out = O::Def; diff --git a/lang/syntax/src/trees/ust/shift.rs b/lang/syntax/src/trees/ust/shift.rs index 24f5d04f0..98a4fa684 100644 --- a/lang/syntax/src/trees/ust/shift.rs +++ b/lang/syntax/src/trees/ust/shift.rs @@ -2,7 +2,7 @@ use crate::common::*; use super::def::*; -impl ShiftInRange for Exp { +impl Shift for Exp { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { match self { Exp::Var { info, name, ctx: (), idx } => Exp::Var { @@ -52,7 +52,7 @@ impl ShiftInRange for Exp { } } -impl ShiftInRange for Motive { +impl Shift for Motive { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Motive { info, param, ret_typ } = self; @@ -64,14 +64,14 @@ impl ShiftInRange for Motive { } } -impl ShiftInRange for Match { +impl Shift for Match { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Match { info, cases, omit_absurd } = self; Match { info: *info, cases: cases.shift_in_range(range, by), omit_absurd: *omit_absurd } } } -impl ShiftInRange for Case { +impl Shift for Case { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let Case { info, name, args, body } = self; Case { @@ -83,14 +83,14 @@ impl ShiftInRange for Case { } } -impl ShiftInRange for TypApp { +impl Shift for TypApp { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { let TypApp { info, name, args } = self; TypApp { info: *info, name: name.clone(), args: args.shift_in_range(range, by) } } } -impl ShiftInRange for Args { +impl Shift for Args { fn shift_in_range(&self, range: R, by: (isize, isize)) -> Self { Self { args: self.args.shift_in_range(range, by) } } diff --git a/lang/syntax/src/trees/ust/subst.rs b/lang/syntax/src/trees/ust/subst.rs index 5986d6a2b..1d26a3faf 100644 --- a/lang/syntax/src/trees/ust/subst.rs +++ b/lang/syntax/src/trees/ust/subst.rs @@ -121,7 +121,7 @@ struct SwapSubst { fst2: usize, } -impl ShiftInRange for SwapSubst { +impl Shift for SwapSubst { fn shift_in_range(&self, _range: R, _by: (isize, isize)) -> Self { // Since SwapSubst works with levels, it is shift-invariant self.clone()