Skip to content

Commit

Permalink
Simplify conversion checking (#445)
Browse files Browse the repository at this point in the history
* Remove unificators from conversion checking

* Update README.md

* Remove guard which evaluates to true
  • Loading branch information
BinderDavid authored Jan 15, 2025
1 parent dc6cb8f commit a9a9c4f
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 125 deletions.
3 changes: 2 additions & 1 deletion lang/elaborator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ This crate contains the code of the elaborator who takes a program which hasn't
The elaborator consists of three main parts:

- The `normalizer` computes the normal form of expressions which is needed to check if two terms are convertible.
- The `unifier` solves a set of equality constraints and produces a most general unifier.
- The algorithm in `index_unification` is used during dependent pattern matching to unify the indices of a type with those in the return type of a specific constructor or destructor.
- The algorithm in `conversion_checking` checks whether two expressions are convertible and also solves metavariables while doing so.
- The `typechecker` traverses the untyped syntax tree, normalizes subexpressions, generates unification problems and outputs a fully elaborated syntax tree.

## Normalizer
Expand Down
119 changes: 10 additions & 109 deletions lang/elaborator/src/conversion_checking/unify.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,24 @@
use std::collections::HashSet;

use ast::ctx::LevelCtx;
use ast::{occurs_in, Variable};
use ast::Variable;
use codespan::Span;
use ctx::GenericCtx;

use crate::result::TypeError;
use ast::*;
use printer::{DocAllocator, Print};
use printer::Print;

use super::constraints::Constraint;
use super::dec::{Dec, No, Yes};

#[derive(Debug, Clone)]
pub struct Unificator {
map: HashMap<Lvl, Box<Exp>>,
}

impl Substitutable for Unificator {
type Result = Unificator;
fn subst<S: Substitution>(&self, ctx: &mut LevelCtx, by: &S) -> Self {
let map = self
.map
.iter()
.map(|(entry_lvl, entry_val)| (*entry_lvl, entry_val.subst(ctx, by)))
.collect();
Self { map }
}
}

impl Shift for Unificator {
fn shift_in_range<R: ShiftRange>(&mut self, range: &R, by: (isize, isize)) {
self.map.iter_mut().for_each(|(_, exp)| exp.shift_in_range(range, by));
}
}

impl Substitution for Unificator {
fn get_subst(&self, _ctx: &LevelCtx, lvl: Lvl) -> Option<Box<Exp>> {
self.map.get(&lvl).cloned()
}
}

impl Unificator {
pub fn empty() -> Self {
Self { map: HashMap::default() }
}
}

pub fn unify(
ctx: LevelCtx,
meta_vars: &mut HashMap<MetaVar, MetaVarState>,
constraint: Constraint,
vars_are_rigid: bool,
while_elaborating_span: &Option<Span>,
) -> Result<Dec<Unificator>, TypeError> {
let mut ctx = Ctx::new(vec![constraint], ctx.clone(), vars_are_rigid);
) -> Result<Dec<()>, TypeError> {
let mut ctx = Ctx::new(vec![constraint]);
let res = match ctx.unify(meta_vars, while_elaborating_span)? {
Yes(_) => Yes(ctx.unif),
Yes(_) => Yes(()),
No(()) => No(()),
};
Ok(res)
Expand All @@ -68,13 +30,6 @@ struct Ctx {
/// A cache of solved constraints. We can skip solving a constraint
/// if we have seen it before
done: HashSet<Constraint>,
ctx: LevelCtx,
/// Partial solution that we have computed from solving previous constraints.
unif: Unificator,
/// When we use the unifier as a conversion checker then we don't want to
/// treat two distinct variables as unifiable. In that case we call the unifier
/// and enable this boolean flag in order to treat all variables as rigid.
vars_are_rigid: bool,
}

/// Tests whether the hole is in Miller's pattern fragment, i.e. whether it is applied
Expand All @@ -99,14 +54,8 @@ fn is_solvable(h: &Hole) -> bool {
}

impl Ctx {
fn new(constraints: Vec<Constraint>, ctx: LevelCtx, vars_are_rigid: bool) -> Self {
Self {
constraints,
done: HashSet::default(),
ctx,
unif: Unificator::empty(),
vars_are_rigid,
}
fn new(constraints: Vec<Constraint>) -> Self {
Self { constraints, done: HashSet::default() }
}

fn unify(
Expand Down Expand Up @@ -134,7 +83,7 @@ impl Ctx {
) -> Result<Dec, TypeError> {
match eqn {
Constraint::Equality { lhs, rhs, .. } => match (&**lhs, &**rhs) {
(Exp::Hole(h), e) | (e, Exp::Hole(h)) if self.vars_are_rigid => {
(Exp::Hole(h), e) | (e, Exp::Hole(h)) => {
let metavar_state = meta_vars.get(&h.metavar).unwrap();
match metavar_state {
MetaVarState::Solved { ctx, solution } => {
Expand Down Expand Up @@ -170,26 +119,12 @@ impl Ctx {
) => {
if idx_1 == idx_2 {
Ok(Yes(()))
} else if self.vars_are_rigid {
Ok(No(()))
} else {
self.add_assignment(*idx_1, rhs.clone())
}
}
(Exp::Variable(Variable { idx, .. }), _) => {
if self.vars_are_rigid {
Ok(No(()))
} else {
self.add_assignment(*idx, rhs.clone())
}
}
(_, Exp::Variable(Variable { idx, .. })) => {
if self.vars_are_rigid {
Ok(No(()))
} else {
self.add_assignment(*idx, lhs.clone())
}
}
(Exp::Variable(Variable { .. }), _) => Ok(No(())),
(_, Exp::Variable(Variable { .. })) => Ok(No(())),
(
Exp::TypCtor(TypCtor { name, args, .. }),
Exp::TypCtor(TypCtor { name: name2, args: args2, .. }),
Expand Down Expand Up @@ -264,25 +199,6 @@ impl Ctx {
}
}

fn add_assignment(&mut self, idx: Idx, exp: Box<Exp>) -> Result<Dec, TypeError> {
if occurs_in(&mut self.ctx, idx, &exp) {
return Err(TypeError::occurs_check_failed(idx, &exp));
}
let insert_lvl = self.ctx.idx_to_lvl(idx);
let exp = exp.subst(&mut self.ctx, &self.unif);
self.unif = self.unif.subst(&mut self.ctx, &Assign { lvl: insert_lvl, exp: exp.clone() });
match self.unif.map.get(&insert_lvl) {
Some(other_exp) => {
let eqn = Constraint::Equality { lhs: exp, rhs: other_exp.clone() };
self.add_constraint(eqn)
}
None => {
self.unif.map.insert(insert_lvl, exp);
Ok(Yes(()))
}
}
}

fn add_constraint(&mut self, eqn: Constraint) -> Result<Dec, TypeError> {
self.add_constraints([eqn])
}
Expand Down Expand Up @@ -329,21 +245,6 @@ impl Ctx {
}
}

impl Print for Unificator {
fn print<'a>(
&'a self,
cfg: &printer::PrintCfg,
alloc: &'a printer::Alloc<'a>,
) -> printer::Builder<'a> {
let mut keys: Vec<_> = self.map.keys().collect();
keys.sort();
let exps = keys.into_iter().map(|key| {
alloc.text(format!("{key}")).append(" := ").append(self.map[key].print(cfg, alloc))
});
alloc.intersperse(exps, ",").enclose("{", "}")
}
}

fn zip_cases_by_xtors(
cases_lhs: &[Case],
cases_rhs: &[Case],
Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/anno.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl CheckInfer for Anno {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(ctx.levels(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl CheckInfer for Call {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(ctx.levels(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}
/// The *inference* rule for calls is:
Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/dot_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl CheckInfer for DotCall {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(ctx.levels(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/local_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl CheckInfer for LocalMatch {
let mut motive_t = ret_typ.subst(&mut subst_ctx, &subst);
motive_t.shift((-1, 0));
let motive_t_nf = motive_t.normalize(&ctx.type_info_table, &mut ctx.env())?;
convert(subst_ctx, &mut ctx.meta_vars, motive_t_nf, t, span)?;
convert(&mut ctx.meta_vars, motive_t_nf, t, span)?;

body_t = ctx.bind_single(&self_binder, |ctx| {
ret_typ.normalize(&ctx.type_info_table, &mut ctx.env())
Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/typ_ctor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl CheckInfer for TypCtor {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(ctx.levels(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
8 changes: 1 addition & 7 deletions lang/elaborator/src/typechecker/exprs/type_univ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@ impl CheckInfer for TypeUniv {
/// P, Γ ⊢ Type ⇐ τ
/// ```
fn check(&self, ctx: &mut Ctx, t: &Exp) -> Result<Self, TypeError> {
convert(
ctx.levels(),
&mut ctx.meta_vars,
Box::new(TypeUniv::new().into()),
t,
&self.span(),
)?;
convert(&mut ctx.meta_vars, Box::new(TypeUniv::new().into()), t, &self.span())?;
Ok(self.clone())
}

Expand Down
2 changes: 1 addition & 1 deletion lang/elaborator/src/typechecker/exprs/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl CheckInfer for Variable {
message: "Expected inferred type".to_owned(),
span: None,
})?;
convert(ctx.levels(), &mut ctx.meta_vars, inferred_typ, t, &self.span())?;
convert(&mut ctx.meta_vars, inferred_typ, t, &self.span())?;
Ok(inferred_term)
}

Expand Down
3 changes: 1 addition & 2 deletions lang/elaborator/src/typechecker/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub fn uses_self(codata: &Codata) -> Result<bool, TypeError> {
}

pub fn convert(
ctx: LevelCtx,
meta_vars: &mut HashMap<MetaVar, MetaVarState>,
this: Box<Exp>,
other: &Exp,
Expand All @@ -31,7 +30,7 @@ pub fn convert(
// Convertibility is checked using the unification algorithm.
let constraint: Constraint =
Constraint::Equality { lhs: this.clone(), rhs: Box::new(other.clone()) };
let res = unify(ctx, meta_vars, constraint, true, while_elaborating_span)?;
let res = unify(meta_vars, constraint, while_elaborating_span)?;
match res {
crate::conversion_checking::dec::Dec::Yes(_) => Ok(()),
crate::conversion_checking::dec::Dec::No(_) => {
Expand Down

0 comments on commit a9a9c4f

Please sign in to comment.