Skip to content

Commit

Permalink
Unify ShiftInRange and Shift traits (#183)
Browse files Browse the repository at this point in the history
* Unify ShiftInRange and Shift traits

* Add comments and tests for Shift trait

* Add more comments
  • Loading branch information
BinderDavid authored Apr 17, 2024
1 parent 8447815 commit b6ebd47
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 62 deletions.
2 changes: 1 addition & 1 deletion lang/elaborator/src/normalizer/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl From<Vec<Vec<Rc<Val>>>> for Env {
}
}

impl ShiftInRange for Env {
impl Shift for Env {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
self.map(|val| val.shift_in_range(range.clone(), by))
}
Expand Down
10 changes: 5 additions & 5 deletions lang/elaborator/src/normalizer/val.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub struct Closure {
pub body: Rc<ust::Exp>,
}

impl ShiftInRange for Val {
impl Shift for Val {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
match self {
Val::TypCtor { info, name, args } => Val::TypCtor {
Expand All @@ -138,7 +138,7 @@ impl ShiftInRange for Val {
}
}

impl ShiftInRange for Neu {
impl Shift for Neu {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
match self {
Neu::Var { info, name, idx } => {
Expand All @@ -161,14 +161,14 @@ impl ShiftInRange for Neu {
}
}

impl ShiftInRange for Match {
impl Shift for Match {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Match { info, cases, omit_absurd } = self;
Match { info: *info, cases: cases.shift_in_range(range, by), omit_absurd: *omit_absurd }
}
}

impl ShiftInRange for Case {
impl Shift for Case {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Case { info, name, args, body } = self;

Expand All @@ -181,7 +181,7 @@ impl ShiftInRange for Case {
}
}

impl ShiftInRange for Closure {
impl Shift for Closure {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Closure { env, n_args, body } = self;

Expand Down
4 changes: 2 additions & 2 deletions lang/elaborator/src/unifier/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl From<(Rc<ust::Exp>, Rc<ust::Exp>)> for Eqn {
}
}

impl ShiftInRange for Eqn {
impl Shift for Eqn {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Eqn { lhs, rhs } = self;
Eqn { lhs: lhs.shift_in_range(range.clone(), by), rhs: rhs.shift_in_range(range, by) }
Expand All @@ -44,7 +44,7 @@ impl Substitutable<Rc<ust::Exp>> for Unificator {
}
}

impl ShiftInRange for Unificator {
impl Shift for Unificator {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Self {
map: self
Expand Down
6 changes: 3 additions & 3 deletions lang/lifting/src/fv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,21 @@ impl FVSubst {
}
}

impl ShiftInRange for FVSubst {
impl Shift for FVSubst {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> Self {
// Since FVSubst works with levels, it is shift-invariant
self.clone()
}
}

impl<'a> ShiftInRange for FVBodySubst<'a> {
impl<'a> Shift for FVBodySubst<'a> {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> FVBodySubst<'a> {
// Since FVSubst works with levels, it is shift-invariant
FVBodySubst { inner: self.inner }
}
}

impl<'a> ShiftInRange for FVParamSubst<'a> {
impl<'a> Shift for FVParamSubst<'a> {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> FVParamSubst<'a> {
// Since FVSubst works with levels, it is shift-invariant
FVParamSubst { inner: self.inner }
Expand Down
24 changes: 12 additions & 12 deletions lang/printer/src/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn is_visible(attr: &Attribute) -> bool {

impl<'a, P: Phase> Print<'a> for Prg<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Prg { decls } = self;
Expand All @@ -27,7 +27,7 @@ where

impl<'a, P: Phase> Print<'a> for Decls<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let items =
Expand All @@ -46,7 +46,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Decl<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand All @@ -70,7 +70,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Item<'a, P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand All @@ -92,7 +92,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Data<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand Down Expand Up @@ -140,7 +140,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Codata<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand Down Expand Up @@ -188,7 +188,7 @@ where

impl<'a, P: Phase> Print<'a> for Def<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Def { info: _, doc, name, attr, params, self_param, ret_typ, body } = self;
Expand Down Expand Up @@ -216,7 +216,7 @@ where

impl<'a, P: Phase> Print<'a> for Codef<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Codef { info: _, doc, name, attr, params, typ, body } = self;
Expand All @@ -242,7 +242,7 @@ where

impl<'a, P: Phase> Print<'a> for Let<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Let { info: _, doc, name, attr, params, typ, body } = self;
Expand All @@ -268,7 +268,7 @@ where

impl<'a, P: Phase> Print<'a> for Ctor<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Ctor { info: _, doc, name, params, typ } = self;
Expand All @@ -287,7 +287,7 @@ where

impl<'a, P: Phase> Print<'a> for Dtor<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Dtor { info: _, doc, name, params, self_param, ret_typ } = self;
Expand Down Expand Up @@ -369,7 +369,7 @@ impl<'a, P: Phase> Print<'a> for Case<P> {

impl<'a, P: Phase> Print<'a> for Telescope<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Telescope { params } = self;
Expand Down
95 changes: 72 additions & 23 deletions lang/syntax/src/common/de_bruijn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,42 @@ pub trait Leveled {
}

/// De-Bruijn shifting
pub trait Shift {
/// Shift a term in the first component of the two-dimensional De-Bruijn index
fn shift(&self, by: (isize, isize)) -> Self;
}

pub trait ShiftRange: RangeBounds<usize> + Clone {}

impl<T: RangeBounds<usize> + Clone> ShiftRange for T {}
///
/// When we manipulate terms using de Bruijn notation we often
/// have to change the de Bruijn indices of the variables inside
/// a term. This is what the "shift" and "shift_in_range" functions
/// from this trait are for.
///
/// Simplified Example: Consider the lambda calculus with de Bruijn
/// indices whose syntax is "e := n | λ_. e | e e". The shift_in_range
/// operation would be defined as follows:
/// - n.shift_in_range(range, by) = if (n ∈ range) then { n + by } else { n }
/// - (λ_. e).shift_in_range(range, by) = λ_.(e.shift_in_range(range.left += 1, by))
/// - (e1 e2).shift_in_range(range, by) = (e1.shift_in_range(range, by)) (e2.shift_in_range(range, by))
/// So whenever we traverse a binding occurrence we have to bump the left
/// side of the range by one.
///
/// Note: We use two-level de Bruijn indices. The cutoff-range only applies to
/// the first element of a two-level de Bruijn index.
///
/// Ref: https://www.cs.cornell.edu/courses/cs4110/2018fa/lectures/lecture15.pdf
pub trait Shift: Sized {
/// Shift all open variables in `self` by the the value indicated with the
/// `by` argument.
fn shift(&self, by: (isize, isize)) -> Self {
self.shift_in_range(0.., by)
}

pub trait ShiftInRange {
/// Shift every de Bruijn index contained in `self` by the value indicated
/// with the `by` argument. De Bruijn indices whose first component does not
/// lie within the indicated `range` are not affected by the shift.
///
/// In order to implement `shift_in_range` correctly you have to increase the
/// left endpoint of `range` by 1 whenever you go recursively under a binder.
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self;
}

impl ShiftInRange for () {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> Self {}
}

impl ShiftInRange for Idx {
impl Shift for Idx {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
if range.contains(&self.fst) {
Self {
Expand All @@ -107,30 +125,61 @@ impl ShiftInRange for Idx {
}
}

impl<T: ShiftInRange> ShiftInRange for Rc<T> {
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn shift_fst() {
let result = Idx { fst: 0, snd: 0 }.shift((1, 0));
assert_eq!(result, Idx { fst: 1, snd: 0 });
}

#[test]
fn shift_snd() {
let result = Idx { fst: 0, snd: 0 }.shift((0, 1));
assert_eq!(result, Idx { fst: 0, snd: 1 });
}

#[test]
fn shift_in_range_fst() {
let result = Idx { fst: 0, snd: 0 }.shift_in_range(1.., (1, 0));
assert_eq!(result, Idx { fst: 0, snd: 0 });
}

#[test]
fn shift_in_range_snd() {
let result = Idx { fst: 0, snd: 0 }.shift_in_range(1.., (0, 1));
assert_eq!(result, Idx { fst: 0, snd: 0 });
}
}

pub trait ShiftRange: RangeBounds<usize> + Clone {}

impl<T: RangeBounds<usize> + Clone> ShiftRange for T {}

impl Shift for () {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> Self {}
}

impl<T: Shift> Shift for Rc<T> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Rc::new((**self).shift_in_range(range, by))
}
}

impl<T: ShiftInRange> ShiftInRange for Option<T> {
impl<T: Shift> Shift for Option<T> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
self.as_ref().map(|inner| inner.shift_in_range(range, by))
}
}

impl<T: ShiftInRange> ShiftInRange for Vec<T> {
impl<T: Shift> Shift for Vec<T> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
self.iter().map(|x| x.shift_in_range(range.clone(), by)).collect()
}
}

impl<T: ShiftInRange> Shift for T {
fn shift(&self, by: (isize, isize)) -> Self {
self.shift_in_range(0.., by)
}
}

pub trait ShiftRangeExt {
type Target;

Expand Down
8 changes: 4 additions & 4 deletions lang/syntax/src/common/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use crate::ctx::*;

pub struct Assign<K, V>(pub K, pub V);

pub trait Substitution<E>: ShiftInRange {
pub trait Substitution<E>: Shift {
fn get_subst(&self, ctx: &LevelCtx, lvl: Lvl) -> Option<E>;
}

impl<E: Clone + ShiftInRange> Substitution<E> for Vec<Vec<E>> {
impl<E: Clone + Shift> Substitution<E> for Vec<Vec<E>> {
fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option<E> {
if lvl.fst >= self.len() {
return None;
Expand All @@ -16,13 +16,13 @@ impl<E: Clone + ShiftInRange> Substitution<E> for Vec<Vec<E>> {
}
}

impl<K: Clone, V: ShiftInRange> ShiftInRange for Assign<K, V> {
impl<K: Clone, V: Shift> Shift for Assign<K, V> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Assign(self.0.clone(), self.1.shift_in_range(range, by))
}
}

impl<E: Clone + ShiftInRange> Substitution<E> for Assign<Lvl, E> {
impl<E: Clone + Shift> Substitution<E> for Assign<Lvl, E> {
fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option<E> {
if self.0 == lvl {
Some(self.1.clone())
Expand Down
2 changes: 1 addition & 1 deletion lang/syntax/src/ctx/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub struct Binder {
pub typ: Rc<ust::Exp>,
}

impl ShiftInRange for Binder {
impl Shift for Binder {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Self { name: self.name.clone(), typ: self.typ.shift_in_range(range, by) }
}
Expand Down
Loading

0 comments on commit b6ebd47

Please sign in to comment.