Skip to content

Commit

Permalink
Add TypeCtx to conversion checking constraints (#450)
Browse files Browse the repository at this point in the history
* Add TypeCtx to conversion checking constraints

* Fix contexts in test doc comments
  • Loading branch information
timsueberkrueb authored Jan 16, 2025
1 parent 2e37373 commit ab3533a
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 45 deletions.
44 changes: 32 additions & 12 deletions lang/elaborator/src/conversion_checking/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
//! This module defines the language of constraints that can be solved by the constraint solver.
use ast::{Args, Exp};
use ast::{ctx::values::TypeCtx, Args, Exp};
use derivative::Derivative;
use printer::Print;

/// A constraint that can be solved by the constraint solver.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
#[derive(Debug, Clone, Derivative)]
#[derivative(Eq, PartialEq, Hash)]
pub enum Constraint {
/// An equality constraint between two expressions.
Equality { lhs: Box<Exp>, rhs: Box<Exp> },
/// An equality constraint between two argument lists.
EqualityArgs { lhs: Args, rhs: Args },
/// An equality constraint between two expressions under the same context.
/// ctx |- lhs = rhs
Equality {
#[derivative(PartialEq = "ignore", Hash = "ignore")]
ctx: TypeCtx,
lhs: Box<Exp>,
rhs: Box<Exp>,
},
/// An equality constraint between two argument lists under the same context.
/// ctx |- lhs = rhs
EqualityArgs {
#[derivative(PartialEq = "ignore", Hash = "ignore")]
ctx: TypeCtx,
lhs: Args,
rhs: Args,
},
}

impl Print for Constraint {
Expand All @@ -18,12 +32,18 @@ impl Print for Constraint {
alloc: &'a printer::Alloc<'a>,
) -> printer::Builder<'a> {
match self {
Constraint::Equality { lhs, rhs } => {
lhs.print(cfg, alloc).append(" = ").append(rhs.print(cfg, alloc))
}
Constraint::EqualityArgs { lhs, rhs } => {
lhs.print(cfg, alloc).append(" = ").append(rhs.print(cfg, alloc))
}
Constraint::Equality { ctx, lhs, rhs } => ctx
.print(cfg, alloc)
.append(" |- ")
.append(lhs.print(cfg, alloc))
.append(" = ")
.append(rhs.print(cfg, alloc)),
Constraint::EqualityArgs { ctx, lhs, rhs } => ctx
.print(cfg, alloc)
.append(" |- ")
.append(lhs.print(cfg, alloc))
.append(" = ")
.append(rhs.print(cfg, alloc)),
}
}
}
87 changes: 71 additions & 16 deletions lang/elaborator/src/conversion_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
//! ```
//! When hovering over the holes in an editor connected to our language server, you will see that both holes are solved with `Nat`.
use ast::{Exp, HashMap, MetaVar, MetaVarState};
use ast::{ctx::values::TypeCtx, Exp, HashMap, MetaVar, MetaVarState};
use codespan::Span;
use constraints::Constraint;
use dec::Dec;
Expand All @@ -67,15 +67,16 @@ mod dec;
mod unify;

pub fn convert(
ctx: TypeCtx,
meta_vars: &mut HashMap<MetaVar, MetaVarState>,
this: Box<Exp>,
other: &Exp,
while_elaborating_span: &Option<Span>,
) -> Result<(), TypeError> {
trace!("{} =? {}", this.print_trace(), other.print_trace());
trace!("{} |- {} =? {}", ctx.print_trace(), this.print_trace(), other.print_trace());
// Convertibility is checked using the unification algorithm.
let constraint: Constraint =
Constraint::Equality { lhs: this.clone(), rhs: Box::new(other.clone()) };
Constraint::Equality { ctx, lhs: this.clone(), rhs: Box::new(other.clone()) };
let mut ctx = Ctx::new(vec![constraint]);
match ctx.unify(meta_vars, while_elaborating_span)? {
Dec::Yes => Ok(()),
Expand All @@ -85,31 +86,34 @@ pub fn convert(

#[cfg(test)]
mod test {
use ast::{HashMap, Idx, MetaVar, MetaVarState, TypeUniv, VarBound, Variable};
use ast::{
ctx::values::{Binder, TypeCtx},
HashMap, Idx, MetaVar, MetaVarState, TypeUniv, VarBind, VarBound, Variable,
};

use crate::conversion_checking::{constraints::Constraint, dec::Dec, unify::Ctx};

/// Assert that the two expressions are convertible
fn check_eq<E: Into<ast::Exp>>(e1: E, e2: E) {
fn check_eq<E: Into<ast::Exp>>(ctx: TypeCtx, e1: E, e2: E) {
let constraint =
Constraint::Equality { lhs: Box::new(e1.into()), rhs: Box::new(e2.into()) };
Constraint::Equality { ctx, lhs: Box::new(e1.into()), rhs: Box::new(e2.into()) };

let mut ctx = Ctx::new(vec![constraint]);
let mut hm: HashMap<MetaVar, MetaVarState> = Default::default();
assert!(ctx.unify(&mut hm, &None).unwrap() == Dec::Yes)
}

/// Assert that the two expressions are not convertible
fn check_neq<E: Into<ast::Exp>>(e1: E, e2: E) {
fn check_neq<E: Into<ast::Exp>>(ctx: TypeCtx, e1: E, e2: E) {
let constraint =
Constraint::Equality { lhs: Box::new(e1.into()), rhs: Box::new(e2.into()) };
Constraint::Equality { ctx, lhs: Box::new(e1.into()), rhs: Box::new(e2.into()) };

let mut ctx = Ctx::new(vec![constraint]);
let mut hm: HashMap<MetaVar, MetaVarState> = Default::default();
assert!(ctx.unify(&mut hm, &None).unwrap() == Dec::No)
}

/// Check that `v =? v` holds.
/// Check that `[[a: Type, v: a]] |- v =? v` holds.
#[test]
fn convert_var_var_1() {
let v = Variable {
Expand All @@ -118,32 +122,83 @@ mod test {
name: VarBound { span: None, id: "x".to_string() },
inferred_type: None,
};
check_eq(v.clone(), v)
let ctx = vec![vec![
Binder {
name: VarBind { span: None, id: "a".to_string() },
typ: Box::new(TypeUniv { span: None }.into()),
},
Binder {
name: VarBind { span: None, id: "v".to_string() },
typ: Box::new(
Variable {
span: None,
idx: Idx { fst: 0, snd: 1 },
name: VarBound { span: None, id: "a".to_string() },
inferred_type: None,
}
.into(),
),
},
]];
check_eq(ctx.into(), v.clone(), v)
}

/// Check that `v =? v'` does not hold.
/// Check that `[[a: Type, v', v]] |- v =? v'` does not hold.
#[test]
fn convert_var_var_2() {
let v1 = Variable {
span: None,
idx: Idx { fst: 0, snd: 0 },
name: VarBound { span: None, id: "x".to_string() },
name: VarBound { span: None, id: "v".to_string() },
inferred_type: None,
};

let v2 = Variable {
span: None,
idx: Idx { fst: 1, snd: 0 },
name: VarBound { span: None, id: "x".to_string() },
name: VarBound { span: None, id: "v'".to_string() },
inferred_type: None,
};
check_neq(v1, v2);

let ctx = vec![vec![
Binder {
name: VarBind { span: None, id: "a".to_string() },
typ: Box::new(TypeUniv { span: None }.into()),
},
Binder {
name: VarBind { span: None, id: "v'".to_string() },
typ: Box::new(
Variable {
span: None,
idx: Idx { fst: 0, snd: 2 },
name: VarBound { span: None, id: "a".to_string() },
inferred_type: None,
}
.into(),
),
},
Binder {
name: VarBind { span: None, id: "v".to_string() },
typ: Box::new(
Variable {
span: None,
idx: Idx { fst: 0, snd: 2 },
name: VarBound { span: None, id: "a".to_string() },
inferred_type: None,
}
.into(),
),
},
]];

check_neq(ctx.into(), v1, v2);
}

/// Check that `Type =? Type` holds.
/// Check that `[] |- Type =? Type` holds.
#[test]
fn convert_type_type() {
let t = TypeUniv { span: None };
check_eq(t.clone(), t);
let ctx = vec![];
check_eq(ctx.into(), t.clone(), t);
}
}
37 changes: 27 additions & 10 deletions lang/elaborator/src/conversion_checking/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ impl Ctx {
while_elaborating_span: &Option<Span>,
) -> Result<Dec, TypeError> {
match eqn {
Constraint::Equality { lhs, rhs, .. } => match (&**lhs, &**rhs) {
Constraint::Equality { ctx: constraint_cxt, lhs, rhs } => match (&**lhs, &**rhs) {
(Exp::Hole(h), e) | (e, Exp::Hole(h)) => {
let metavar_state = meta_vars.get(&h.metavar).unwrap();
match metavar_state {
MetaVarState::Solved { ctx, solution } => {
let lhs = solution.clone().subst(&mut ctx.clone(), &h.args);
self.add_constraint(Constraint::Equality {
ctx: constraint_cxt.clone(),
lhs,
rhs: Box::new(e.clone()),
})?;
Expand Down Expand Up @@ -116,8 +117,11 @@ impl Ctx {
Exp::TypCtor(TypCtor { name, args, .. }),
Exp::TypCtor(TypCtor { name: name2, args: args2, .. }),
) if name == name2 => {
let constraint =
Constraint::EqualityArgs { lhs: args.clone(), rhs: args2.clone() };
let constraint = Constraint::EqualityArgs {
ctx: constraint_cxt.clone(),
lhs: args.clone(),
rhs: args2.clone(),
};
self.add_constraint(constraint)
}
(Exp::TypCtor(TypCtor { name, .. }), Exp::TypCtor(TypCtor { name: name2, .. }))
Expand All @@ -129,8 +133,11 @@ impl Ctx {
Exp::Call(Call { name, args, .. }),
Exp::Call(Call { name: name2, args: args2, .. }),
) if name == name2 => {
let constraint =
Constraint::EqualityArgs { lhs: args.clone(), rhs: args2.clone() };
let constraint = Constraint::EqualityArgs {
ctx: constraint_cxt.clone(),
lhs: args.clone(),
rhs: args2.clone(),
};
self.add_constraint(constraint)
}
(Exp::Call(Call { name, .. }), Exp::Call(Call { name: name2, .. }))
Expand All @@ -143,19 +150,25 @@ impl Ctx {
Exp::DotCall(DotCall { exp: exp2, name: name2, args: args2, .. }),
) if name == name2 => {
self.add_constraint(Constraint::Equality {
ctx: constraint_cxt.clone(),
lhs: exp.clone(),
rhs: exp2.clone(),
})?;
let constraint =
Constraint::EqualityArgs { lhs: args.clone(), rhs: args2.clone() };
let constraint = Constraint::EqualityArgs {
ctx: constraint_cxt.clone(),
lhs: args.clone(),
rhs: args2.clone(),
};
self.add_constraint(constraint)
}
(Exp::TypeUniv(_), Exp::TypeUniv(_)) => Ok(Yes),
(Exp::Anno(Anno { exp, .. }), rhs) => self.add_constraint(Constraint::Equality {
ctx: constraint_cxt.clone(),
lhs: exp.clone(),
rhs: Box::new(rhs.clone()),
}),
(lhs, Exp::Anno(Anno { exp, .. })) => self.add_constraint(Constraint::Equality {
ctx: constraint_cxt.clone(),
lhs: Box::new(lhs.clone()),
rhs: exp.clone(),
}),
Expand All @@ -166,7 +179,7 @@ impl Ctx {
let new_eqns =
zip_cases_by_xtors(cases_lhs, cases_rhs).filter_map(|(lhs, rhs)| {
if let (Some(lhs), Some(rhs)) = (lhs.body, rhs.body) {
Some(Constraint::Equality { lhs, rhs })
Some(Constraint::Equality { ctx: constraint_cxt.clone(), lhs, rhs })
} else {
None
}
Expand All @@ -175,10 +188,14 @@ impl Ctx {
}
(_, _) => Err(TypeError::cannot_decide(lhs, rhs, while_elaborating_span)),
},
Constraint::EqualityArgs { lhs, rhs } => {
Constraint::EqualityArgs { ctx: constraint_ctx, lhs, rhs } => {
let new_eqns =
lhs.args.iter().cloned().zip(rhs.args.iter().cloned()).map(|(lhs, rhs)| {
Constraint::Equality { lhs: lhs.exp().clone(), rhs: rhs.exp().clone() }
Constraint::Equality {
ctx: constraint_ctx.clone(),
lhs: lhs.exp().clone(),
rhs: rhs.exp().clone(),
}
});
self.add_constraints(new_eqns)?;
Ok(Yes)
Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/anno.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl CheckInfer for Anno {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(ctx.vars.clone(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl CheckInfer for Call {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(ctx.vars.clone(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}
/// The *inference* rule for calls is:
Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/dot_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl CheckInfer for DotCall {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(ctx.vars.clone(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/local_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl CheckInfer for LocalMatch {
let mut motive_t = ret_typ.subst(&mut subst_ctx, &subst);
motive_t.shift((-1, 0));
let motive_t_nf = motive_t.normalize(&ctx.type_info_table, &mut ctx.env())?;
convert(&mut ctx.meta_vars, motive_t_nf, t, span)?;
convert(ctx.vars.clone(), &mut ctx.meta_vars, motive_t_nf, t, span)?;

body_t = ctx.bind_single(&self_binder, |ctx| {
ret_typ.normalize(&ctx.type_info_table, &mut ctx.env())
Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/typ_ctor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl CheckInfer for TypCtor {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(ctx.vars.clone(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
8 changes: 7 additions & 1 deletion lang/elaborator/src/typechecker/exprs/type_univ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ impl CheckInfer for TypeUniv {
/// P, Γ ⊢ Type ⇐ τ
/// ```
fn check(&self, ctx: &mut Ctx, t: &Exp) -> Result<Self, TypeError> {
convert(&mut ctx.meta_vars, Box::new(TypeUniv::new().into()), t, &self.span())?;
convert(
ctx.vars.clone(),
&mut ctx.meta_vars,
Box::new(TypeUniv::new().into()),
t,
&self.span(),
)?;
Ok(self.clone())
}

Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl CheckInfer for Variable {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(ctx.vars.clone(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down

0 comments on commit ab3533a

Please sign in to comment.