Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify ShiftInRange and Shift traits #183

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lang/elaborator/src/normalizer/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl From<Vec<Vec<Rc<Val>>>> for Env {
}
}

impl ShiftInRange for Env {
impl Shift for Env {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
self.map(|val| val.shift_in_range(range.clone(), by))
}
Expand Down
10 changes: 5 additions & 5 deletions lang/elaborator/src/normalizer/val.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub struct Closure {
pub body: Rc<ust::Exp>,
}

impl ShiftInRange for Val {
impl Shift for Val {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
match self {
Val::TypCtor { info, name, args } => Val::TypCtor {
Expand All @@ -138,7 +138,7 @@ impl ShiftInRange for Val {
}
}

impl ShiftInRange for Neu {
impl Shift for Neu {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
match self {
Neu::Var { info, name, idx } => {
Expand All @@ -161,14 +161,14 @@ impl ShiftInRange for Neu {
}
}

impl ShiftInRange for Match {
impl Shift for Match {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Match { info, cases, omit_absurd } = self;
Match { info: *info, cases: cases.shift_in_range(range, by), omit_absurd: *omit_absurd }
}
}

impl ShiftInRange for Case {
impl Shift for Case {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Case { info, name, args, body } = self;

Expand All @@ -181,7 +181,7 @@ impl ShiftInRange for Case {
}
}

impl ShiftInRange for Closure {
impl Shift for Closure {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Closure { env, n_args, body } = self;

Expand Down
4 changes: 2 additions & 2 deletions lang/elaborator/src/unifier/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl From<(Rc<ust::Exp>, Rc<ust::Exp>)> for Eqn {
}
}

impl ShiftInRange for Eqn {
impl Shift for Eqn {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
let Eqn { lhs, rhs } = self;
Eqn { lhs: lhs.shift_in_range(range.clone(), by), rhs: rhs.shift_in_range(range, by) }
Expand All @@ -44,7 +44,7 @@ impl Substitutable<Rc<ust::Exp>> for Unificator {
}
}

impl ShiftInRange for Unificator {
impl Shift for Unificator {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Self {
map: self
Expand Down
6 changes: 3 additions & 3 deletions lang/lifting/src/fv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,21 @@ impl FVSubst {
}
}

impl ShiftInRange for FVSubst {
impl Shift for FVSubst {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> Self {
// Since FVSubst works with levels, it is shift-invariant
self.clone()
}
}

impl<'a> ShiftInRange for FVBodySubst<'a> {
impl<'a> Shift for FVBodySubst<'a> {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> FVBodySubst<'a> {
// Since FVSubst works with levels, it is shift-invariant
FVBodySubst { inner: self.inner }
}
}

impl<'a> ShiftInRange for FVParamSubst<'a> {
impl<'a> Shift for FVParamSubst<'a> {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> FVParamSubst<'a> {
// Since FVSubst works with levels, it is shift-invariant
FVParamSubst { inner: self.inner }
Expand Down
24 changes: 12 additions & 12 deletions lang/printer/src/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn is_visible(attr: &Attribute) -> bool {

impl<'a, P: Phase> Print<'a> for Prg<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Prg { decls } = self;
Expand All @@ -27,7 +27,7 @@ where

impl<'a, P: Phase> Print<'a> for Decls<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let items =
Expand All @@ -46,7 +46,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Decl<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand All @@ -70,7 +70,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Item<'a, P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand All @@ -92,7 +92,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Data<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand Down Expand Up @@ -140,7 +140,7 @@ where

impl<'a, P: Phase> PrintInCtx<'a> for Codata<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
type Ctx = Decls<P>;

Expand Down Expand Up @@ -188,7 +188,7 @@ where

impl<'a, P: Phase> Print<'a> for Def<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Def { info: _, doc, name, attr, params, self_param, ret_typ, body } = self;
Expand Down Expand Up @@ -216,7 +216,7 @@ where

impl<'a, P: Phase> Print<'a> for Codef<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Codef { info: _, doc, name, attr, params, typ, body } = self;
Expand All @@ -242,7 +242,7 @@ where

impl<'a, P: Phase> Print<'a> for Let<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Let { info: _, doc, name, attr, params, typ, body } = self;
Expand All @@ -268,7 +268,7 @@ where

impl<'a, P: Phase> Print<'a> for Ctor<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Ctor { info: _, doc, name, params, typ } = self;
Expand All @@ -287,7 +287,7 @@ where

impl<'a, P: Phase> Print<'a> for Dtor<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Dtor { info: _, doc, name, params, self_param, ret_typ } = self;
Expand Down Expand Up @@ -369,7 +369,7 @@ impl<'a, P: Phase> Print<'a> for Case<P> {

impl<'a, P: Phase> Print<'a> for Telescope<P>
where
Exp<P>: ShiftInRange,
Exp<P>: Shift,
{
fn print(&'a self, cfg: &PrintCfg, alloc: &'a Alloc<'a>) -> Builder<'a> {
let Telescope { params } = self;
Expand Down
95 changes: 72 additions & 23 deletions lang/syntax/src/common/de_bruijn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,42 @@ pub trait Leveled {
}

/// De-Bruijn shifting
pub trait Shift {
/// Shift a term in the first component of the two-dimensional De-Bruijn index
fn shift(&self, by: (isize, isize)) -> Self;
}

pub trait ShiftRange: RangeBounds<usize> + Clone {}

impl<T: RangeBounds<usize> + Clone> ShiftRange for T {}
///
/// When we manipulate terms using de Bruijn notation we often
/// have to change the de Bruijn indices of the variables inside
/// a term. This is what the "shift" and "shift_in_range" functions
/// from this trait are for.
///
/// Simplified Example: Consider the lambda calculus with de Bruijn
/// indices whose syntax is "e := n | λ_. e | e e". The shift_in_range
/// operation would be defined as follows:
/// - n.shift_in_range(range, by) = if (n ∈ range) then { n + by } else { n }
/// - (λ_. e).shift_in_range(range, by) = λ_.(e.shift_in_range(range.left += 1, by))
/// - (e1 e2).shift_in_range(range, by) = (e1.shift_in_range(range, by)) (e2.shift_in_range(range, by))
/// So whenever we traverse a binding occurrence we have to bump the left
/// side of the range by one.
///
/// Note: We use two-level de Bruijn indices. The cutoff-range only applies to
/// the first element of a two-level de Bruijn index.
///
/// Ref: https://www.cs.cornell.edu/courses/cs4110/2018fa/lectures/lecture15.pdf
pub trait Shift: Sized {
/// Shift all open variables in `self` by the the value indicated with the
/// `by` argument.
fn shift(&self, by: (isize, isize)) -> Self {
self.shift_in_range(0.., by)
}

pub trait ShiftInRange {
/// Shift every de Bruijn index contained in `self` by the value indicated
/// with the `by` argument. De Bruijn indices whose first component does not
/// lie within the indicated `range` are not affected by the shift.
///
/// In order to implement `shift_in_range` correctly you have to increase the
/// left endpoint of `range` by 1 whenever you go recursively under a binder.
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self;
}

impl ShiftInRange for () {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> Self {}
}

impl ShiftInRange for Idx {
impl Shift for Idx {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
if range.contains(&self.fst) {
Self {
Expand All @@ -107,30 +125,61 @@ impl ShiftInRange for Idx {
}
}

impl<T: ShiftInRange> ShiftInRange for Rc<T> {
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn shift_fst() {
let result = Idx { fst: 0, snd: 0 }.shift((1, 0));
assert_eq!(result, Idx { fst: 1, snd: 0 });
}

#[test]
fn shift_snd() {
let result = Idx { fst: 0, snd: 0 }.shift((0, 1));
assert_eq!(result, Idx { fst: 0, snd: 1 });
}

#[test]
fn shift_in_range_fst() {
let result = Idx { fst: 0, snd: 0 }.shift_in_range(1.., (1, 0));
assert_eq!(result, Idx { fst: 0, snd: 0 });
}

#[test]
fn shift_in_range_snd() {
let result = Idx { fst: 0, snd: 0 }.shift_in_range(1.., (0, 1));
assert_eq!(result, Idx { fst: 0, snd: 0 });
}
}

pub trait ShiftRange: RangeBounds<usize> + Clone {}

impl<T: RangeBounds<usize> + Clone> ShiftRange for T {}

impl Shift for () {
fn shift_in_range<R: ShiftRange>(&self, _range: R, _by: (isize, isize)) -> Self {}
}

impl<T: Shift> Shift for Rc<T> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Rc::new((**self).shift_in_range(range, by))
}
}

impl<T: ShiftInRange> ShiftInRange for Option<T> {
impl<T: Shift> Shift for Option<T> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
self.as_ref().map(|inner| inner.shift_in_range(range, by))
}
}

impl<T: ShiftInRange> ShiftInRange for Vec<T> {
impl<T: Shift> Shift for Vec<T> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
self.iter().map(|x| x.shift_in_range(range.clone(), by)).collect()
}
}

impl<T: ShiftInRange> Shift for T {
fn shift(&self, by: (isize, isize)) -> Self {
self.shift_in_range(0.., by)
}
}

pub trait ShiftRangeExt {
type Target;

Expand Down
8 changes: 4 additions & 4 deletions lang/syntax/src/common/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use crate::ctx::*;

pub struct Assign<K, V>(pub K, pub V);

pub trait Substitution<E>: ShiftInRange {
pub trait Substitution<E>: Shift {
fn get_subst(&self, ctx: &LevelCtx, lvl: Lvl) -> Option<E>;
}

impl<E: Clone + ShiftInRange> Substitution<E> for Vec<Vec<E>> {
impl<E: Clone + Shift> Substitution<E> for Vec<Vec<E>> {
fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option<E> {
if lvl.fst >= self.len() {
return None;
Expand All @@ -16,13 +16,13 @@ impl<E: Clone + ShiftInRange> Substitution<E> for Vec<Vec<E>> {
}
}

impl<K: Clone, V: ShiftInRange> ShiftInRange for Assign<K, V> {
impl<K: Clone, V: Shift> Shift for Assign<K, V> {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Assign(self.0.clone(), self.1.shift_in_range(range, by))
}
}

impl<E: Clone + ShiftInRange> Substitution<E> for Assign<Lvl, E> {
impl<E: Clone + Shift> Substitution<E> for Assign<Lvl, E> {
fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option<E> {
if self.0 == lvl {
Some(self.1.clone())
Expand Down
2 changes: 1 addition & 1 deletion lang/syntax/src/ctx/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub struct Binder {
pub typ: Rc<ust::Exp>,
}

impl ShiftInRange for Binder {
impl Shift for Binder {
fn shift_in_range<R: ShiftRange>(&self, range: R, by: (isize, isize)) -> Self {
Self { name: self.name.clone(), typ: self.typ.shift_in_range(range, by) }
}
Expand Down
Loading