Skip to content

Commit

Permalink
[sc-301] Add pattern matching on non-zero numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
developedby committed Feb 27, 2024
1 parent 28d456d commit 39ffe05
Show file tree
Hide file tree
Showing 99 changed files with 627 additions and 234 deletions.
7 changes: 2 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,8 @@ pub fn create_host(book: Arc<Book>, labels: Arc<Labels>, compile_opts: CompileOp
term.resugar_builtins();

readback_errors.extend(resugar_errs);
match term {
Term::Str { val } => {
println!("{}", val);
}
_ => (),
if let Term::Str { val } = term {
println!("{val}");
}
}
}))),
Expand Down
1 change: 0 additions & 1 deletion src/term/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ impl Pattern {
snd.encode_builtins();
}
Pattern::Var(..) | Pattern::Num(..) => {}
Pattern::Err => unreachable!(),
}
}

Expand Down
1 change: 0 additions & 1 deletion src/term/check/ctrs_arities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl Pattern {
}
}
Pattern::Var(..) | Pattern::Num(..) => {}
Pattern::Err => unreachable!(),
}
}
Ok(())
Expand Down
14 changes: 13 additions & 1 deletion src/term/check/type_check.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::term::{transform::encode_pattern_matching::MatchErr, Constructors, Name, Pattern, Type};
use crate::term::{transform::encode_pattern_matching::MatchErr, Constructors, Name, Pattern, Rule, Type};
use indexmap::IndexMap;

pub type DefinitionTypes = IndexMap<Name, Vec<Type>>;

pub fn infer_match_arg_type(rules: &[Rule], arg_idx: usize, ctrs: &Constructors) -> Result<Type, MatchErr> {
infer_type(rules.iter().map(|r| &r.pats[arg_idx]), ctrs)
}

/// Infers the type of a sequence of arguments
pub fn infer_type<'a>(
pats: impl IntoIterator<Item = &'a Pattern>,
Expand All @@ -19,9 +23,17 @@ fn unify(old: Type, new: Type) -> Result<Type, MatchErr> {
match (old, new) {
(Type::Any, new) => Ok(new),
(old, Type::Any) => Ok(old),

(Type::Adt(old), Type::Adt(new)) if new == old => Ok(Type::Adt(old)),

(Type::Num, Type::Num) => Ok(Type::Num),
(Type::Num, Type::NumSucc(n)) => Ok(Type::NumSucc(n)),

(Type::NumSucc(n), Type::Num) => Ok(Type::NumSucc(n)),
(Type::NumSucc(a), Type::NumSucc(b)) if a == b => Ok(Type::NumSucc(a)),

(Type::Tup, Type::Tup) => Ok(Type::Tup),

(old, new) => Err(MatchErr::TypeMismatch(new, old)),
}
}
1 change: 0 additions & 1 deletion src/term/check/unbound_pats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl Pattern {
}
Pattern::Lst(args) => args.iter().for_each(|arg| check.push(arg)),
Pattern::Var(_) | Pattern::Num(_) => {}
Pattern::Err => unreachable!(),
}
}
unbounds
Expand Down
16 changes: 7 additions & 9 deletions src/term/display.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use super::{
net_to_term::ReadbackError, Book, Definition, MatchNum, Name, Op, Pattern, Rule, Tag, Term, Type,
};
use super::{net_to_term::ReadbackError, Book, Definition, Name, NumCtr, Op, Pattern, Rule, Tag, Term, Type};
use std::{fmt, ops::Deref};

/* Some aux structures for things that are not so simple to display */
Expand Down Expand Up @@ -113,7 +111,6 @@ impl fmt::Display for Pattern {
Pattern::Num(num) => write!(f, "{num}"),
Pattern::Tup(fst, snd) => write!(f, "({}, {})", fst, snd,),
Pattern::Lst(pats) => write!(f, "[{}]", DisplayJoin(|| pats, ", ")),
Pattern::Err => write!(f, "<Invalid>"),
}
}
}
Expand Down Expand Up @@ -141,13 +138,13 @@ impl fmt::Display for Book {
}
}

impl fmt::Display for MatchNum {
impl fmt::Display for NumCtr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MatchNum::Zero => write!(f, "0"),
MatchNum::Succ(None) => write!(f, "+"),
MatchNum::Succ(Some(None)) => write!(f, "+*"),
MatchNum::Succ(Some(Some(nam))) => write!(f, "+{nam}"),
NumCtr::Num(n) => write!(f, "{n}"),
NumCtr::Succ(n, None) => write!(f, "{n}+"),
NumCtr::Succ(n, Some(None)) => write!(f, "{n}+*"),
NumCtr::Succ(n, Some(Some(nam))) => write!(f, "{n}+{nam}"),
}
}
}
Expand Down Expand Up @@ -188,6 +185,7 @@ impl fmt::Display for Type {
Type::Any => write!(f, "any"),
Type::Tup => write!(f, "tup"),
Type::Num => write!(f, "num"),
Type::NumSucc(n) => write!(f, "{n}+"),
Type::Adt(nam) => write!(f, "{nam}"),
}
}
Expand Down
136 changes: 92 additions & 44 deletions src/term/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use self::{check::type_check::infer_type, parser::lexer::STRINGS};
use self::{check::type_check::infer_match_arg_type, parser::lexer::STRINGS};
use crate::{diagnostics::Info, term::builtins::*, ENTRY_POINT};
use indexmap::{IndexMap, IndexSet};
use interner::global::GlobalString;
use itertools::Itertools;
use std::{borrow::Cow, collections::HashMap, ops::Deref};
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
ops::Deref,
};

pub mod builtins;
pub mod check;
Expand Down Expand Up @@ -136,21 +140,19 @@ pub enum Term {
Err,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Pattern {
Var(Option<Name>),
Ctr(Name, Vec<Pattern>),
Num(MatchNum),
Num(NumCtr),
Tup(Box<Pattern>, Box<Pattern>),
Lst(Vec<Pattern>),
#[default]
Err,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MatchNum {
Zero,
Succ(Option<Option<Name>>),
pub enum NumCtr {
Num(u64),
Succ(u64, Option<Option<Name>>),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -185,9 +187,15 @@ pub enum Op {
/// Pattern types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
/// Variables/wildcards.
Any,
/// A native tuple.
Tup,
/// A sequence of arbitrary numbers ending in a variable.
Num,
/// A strictly incrementing sequence of numbers starting from 0, ending in a + ctr.
NumSucc(u64),
/// Adt constructors declared with the `data` syntax.
Adt(Name),
}

Expand Down Expand Up @@ -239,6 +247,7 @@ impl Tag {
}

impl Term {
/* Common construction patterns */
pub fn lam(nam: Option<Name>, bod: Term) -> Self {
Term::Lam { tag: Tag::Static, nam, bod: Box::new(bod) }
}
Expand Down Expand Up @@ -289,6 +298,30 @@ impl Term {
Term::Str { val: STRINGS.get(str) }
}

pub fn native_num_match(arg: Term, zero: Term, succ: Term) -> Term {
let zero = Rule { pats: vec![Pattern::Num(NumCtr::Num(0))], body: zero };
let succ = Rule { pats: vec![Pattern::Num(NumCtr::Succ(1, None))], body: succ };
Term::Mat { args: vec![arg], rules: vec![zero, succ] }
}

pub fn sub_num(arg: Term, val: u64) -> Term {
if val == 0 {
arg
} else {
Term::Opx { op: Op::SUB, fst: Box::new(arg), snd: Box::new(Term::Num { val }) }
}
}

pub fn add_num(arg: Term, val: u64) -> Term {
if val == 0 {
arg
} else {
Term::Opx { op: Op::ADD, fst: Box::new(arg), snd: Box::new(Term::Num { val }) }
}
}

/* Common checks and transformations */

/// Substitute the occurrences of a variable in a term with the given term.
/// Caution: can cause invalid shadowing of variables if used incorrectly.
/// Ex: Using subst to beta-reduce (@a @b a b) converting it into (@b b).
Expand Down Expand Up @@ -507,7 +540,7 @@ impl Term {
return false;
}
// The match is over a valid type
let Ok(typ) = infer_type(rules.iter().map(|r| &r.pats[0]), ctrs) else {
let Ok(typ) = infer_match_arg_type(rules, 0, ctrs) else {
return false;
};
// The match has one arm for each constructor, matching the constructors in adt declaration order
Expand All @@ -529,14 +562,32 @@ impl Term {
}
}
Type::Num => {
if rules.len() != 2 {
return false;
let mut nums = HashSet::new();
for rule in rules {
if let Pattern::Num(NumCtr::Num(n)) = &rule.pats[0] {
if nums.contains(n) {
return false;
}
nums.insert(*n);
}
}
if !matches!(rules[0].pats[0], Pattern::Num(MatchNum::Zero)) {
}
Type::NumSucc(n) => {
if rules.len() as u64 != n + 1 {
return false;
}
if !matches!(rules[1].pats[0], Pattern::Num(MatchNum::Succ(Some(_)))) {
return false;
for (i, _) in rules.iter().enumerate() {
if i as u64 == n {
let Pattern::Num(NumCtr::Succ(n_pat, Some(_))) = &rules[i].pats[0] else { return false };
if n != *n_pat {
return false;
}
} else {
let Pattern::Num(NumCtr::Num(i_pat)) = &rules[i].pats[0] else { return false };
if i as u64 != *i_pat {
return false;
}
}
}
}
Type::Adt(adt) => {
Expand Down Expand Up @@ -570,7 +621,7 @@ impl Pattern {
pub fn bind_or_eras(&self) -> impl DoubleEndedIterator<Item = &Option<Name>> {
self.iter().filter_map(|pat| match pat {
Pattern::Var(nam) => Some(nam),
Pattern::Num(MatchNum::Succ(nam)) => nam.as_ref(),
Pattern::Num(NumCtr::Succ(_, nam)) => nam.as_ref(),
_ => None,
})
}
Expand All @@ -585,11 +636,10 @@ impl Pattern {
go(fst, set);
go(snd, set);
}
Pattern::Num(MatchNum::Succ(Some(nam))) => {
Pattern::Num(NumCtr::Succ(_, Some(nam))) => {
set.push(nam);
}
Pattern::Num(_) => {}
Pattern::Err => unreachable!(),
}
}
let mut set = Vec::new();
Expand All @@ -614,7 +664,6 @@ impl Pattern {
Pattern::Num(_) => Box::new([].iter()),
Pattern::Tup(fst, snd) => Box::new([fst.as_ref(), snd.as_ref()].into_iter()),
Pattern::Lst(els) => Box::new(els.iter()),
Pattern::Err => unreachable!(),
}
}

Expand All @@ -634,24 +683,24 @@ impl Pattern {
match self {
Pattern::Var(_) => None,
Pattern::Ctr(nam, _) => Some(nam.clone()),
Pattern::Num(MatchNum::Zero) => Some(Name::new("0")),
Pattern::Num(MatchNum::Succ(_)) => Some(Name::new("+")),
Pattern::Num(NumCtr::Num(num)) => Some(Name::new(format!("{num}"))),
Pattern::Num(NumCtr::Succ(num, _)) => Some(Name::new(format!("{num}+"))),
Pattern::Tup(_, _) => Some(Name::new("(,)")),
Pattern::Lst(_) => todo!(),
Pattern::Err => unreachable!(),
}
}

pub fn is_wildcard(&self) -> bool {
matches!(self, Pattern::Var(_))
}

pub fn is_detached_num_match(&self) -> bool {
if let Pattern::Num(num) = self {
match num {
MatchNum::Zero => true,
MatchNum::Succ(None) => true,
MatchNum::Succ(Some(_)) => false,
pub fn is_native_num_match(&self) -> bool {
if let Pattern::Num(ctr) = self {
match ctr {
NumCtr::Num(0) => true,
NumCtr::Num(_) => false,
NumCtr::Succ(1, None) => true,
NumCtr::Succ(_, _) => false,
}
} else {
false
Expand All @@ -667,7 +716,6 @@ impl Pattern {
Pattern::Tup(fst, snd) => {
matches!(fst.as_ref(), Pattern::Var(_)) && matches!(snd.as_ref(), Pattern::Var(_))
}
Pattern::Err => unreachable!(),
}
}

Expand All @@ -679,9 +727,9 @@ impl Pattern {
Type::Adt(adt_nam.clone())
}
Pattern::Tup(..) => Type::Tup,
Pattern::Num(..) => Type::Num,
Pattern::Num(NumCtr::Num(_)) => Type::Num,
Pattern::Num(NumCtr::Succ(n, _)) => Type::NumSucc(*n),
Pattern::Lst(..) => Type::Adt(builtins::LIST.into()),
Pattern::Err => unreachable!(),
}
}

Expand All @@ -691,31 +739,28 @@ impl Pattern {
Pattern::Ctr(ctr, args) => {
Term::call(Term::Ref { nam: ctr.clone() }, args.iter().map(|arg| arg.to_term()))
}
Pattern::Num(MatchNum::Zero) => Term::Num { val: 0 },
Pattern::Num(NumCtr::Num(val)) => Term::Num { val: *val },
// Succ constructor with no variable is not a valid term, only a compiler intermediate for a MAT inet node.
Pattern::Num(MatchNum::Succ(None)) => unreachable!(),
Pattern::Num(MatchNum::Succ(Some(Some(nam)))) => Term::Opx {
op: Op::ADD,
fst: Box::new(Term::Var { nam: nam.clone() }),
snd: Box::new(Term::Num { val: 1 }),
},
Pattern::Num(MatchNum::Succ(Some(None))) => Term::Era,
Pattern::Num(NumCtr::Succ(_, None)) => unreachable!(),
Pattern::Num(NumCtr::Succ(val, Some(Some(nam)))) => Term::add_num(Term::Var { nam: nam.clone() }, *val),
Pattern::Num(NumCtr::Succ(_, Some(None))) => Term::Era,
Pattern::Tup(fst, snd) => Term::Tup { fst: Box::new(fst.to_term()), snd: Box::new(snd.to_term()) },
Pattern::Lst(_) => {
let mut p = self.clone();
p.encode_builtins();
p.to_term()
}
Pattern::Err => unreachable!(),
}
}

/// True if both patterns are equal (match the same expressions) without considering nested patterns.
pub fn simple_equals(&self, other: &Pattern) -> bool {
match (self, other) {
(Pattern::Ctr(a, _), Pattern::Ctr(b, _)) if a == b => true,
(Pattern::Num(MatchNum::Zero), Pattern::Num(MatchNum::Zero)) => true,
(Pattern::Num(MatchNum::Succ(_)), Pattern::Num(MatchNum::Succ(_))) => true,
(Pattern::Num(NumCtr::Num(a)), Pattern::Num(NumCtr::Num(b))) if a == b => true,
(Pattern::Num(NumCtr::Num(_)), Pattern::Num(NumCtr::Num(_))) => false,
(Pattern::Num(NumCtr::Succ(a, _)), Pattern::Num(NumCtr::Succ(b, _))) if a == b => true,
(Pattern::Num(NumCtr::Succ(_, _)), Pattern::Num(NumCtr::Succ(_, _))) => false,
(Pattern::Tup(_, _), Pattern::Tup(_, _)) => true,
(Pattern::Lst(_), Pattern::Lst(_)) => true,
(Pattern::Var(_), Pattern::Var(_)) => true,
Expand Down Expand Up @@ -774,9 +819,12 @@ impl Type {
Box::new(Pattern::Var(Some("%fst".into()))),
Box::new(Pattern::Var(Some("%snd".into()))),
)],
Type::Num => {
vec![Pattern::Num(MatchNum::Zero), Pattern::Num(MatchNum::Succ(Some(Some("%pred".into()))))]
Type::NumSucc(n) => {
let mut ctrs = (0 .. *n).map(|n| Pattern::Num(NumCtr::Num(n))).collect::<Vec<_>>();
ctrs.push(Pattern::Num(NumCtr::Succ(*n, Some(Some("%pred".into())))));
ctrs
}
Type::Num => unreachable!(),
Type::Adt(adt) => {
// TODO: Should return just a ref to ctrs and not clone.
adts[adt]
Expand Down
Loading

0 comments on commit 39ffe05

Please sign in to comment.