Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sc-301] Add support for matching on non zero numbers, add char patterns #211

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,8 @@ pub fn create_host(book: Arc<Book>, labels: Arc<Labels>, compile_opts: CompileOp
term.resugar_builtins();

readback_errors.extend(resugar_errs);
match term {
Term::Str { val } => {
println!("{}", val);
}
_ => (),
if let Term::Str { val } = term {
println!("{val}");
}
}
}))),
Expand Down
1 change: 0 additions & 1 deletion src/term/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ impl Pattern {
snd.encode_builtins();
}
Pattern::Var(..) | Pattern::Num(..) => {}
Pattern::Err => unreachable!(),
}
}

Expand Down
1 change: 0 additions & 1 deletion src/term/check/ctrs_arities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl Pattern {
}
}
Pattern::Var(..) | Pattern::Num(..) => {}
Pattern::Err => unreachable!(),
}
}
Ok(())
Expand Down
14 changes: 13 additions & 1 deletion src/term/check/type_check.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::term::{transform::encode_pattern_matching::MatchErr, Constructors, Name, Pattern, Type};
use crate::term::{transform::encode_pattern_matching::MatchErr, Constructors, Name, Pattern, Rule, Type};
use indexmap::IndexMap;

pub type DefinitionTypes = IndexMap<Name, Vec<Type>>;

pub fn infer_match_arg_type(rules: &[Rule], arg_idx: usize, ctrs: &Constructors) -> Result<Type, MatchErr> {
infer_type(rules.iter().map(|r| &r.pats[arg_idx]), ctrs)
}

/// Infers the type of a sequence of arguments
pub fn infer_type<'a>(
pats: impl IntoIterator<Item = &'a Pattern>,
Expand All @@ -19,9 +23,17 @@ fn unify(old: Type, new: Type) -> Result<Type, MatchErr> {
match (old, new) {
(Type::Any, new) => Ok(new),
(old, Type::Any) => Ok(old),

(Type::Adt(old), Type::Adt(new)) if new == old => Ok(Type::Adt(old)),

(Type::Num, Type::Num) => Ok(Type::Num),
(Type::Num, Type::NumSucc(n)) => Ok(Type::NumSucc(n)),

(Type::NumSucc(n), Type::Num) => Ok(Type::NumSucc(n)),
(Type::NumSucc(a), Type::NumSucc(b)) if a == b => Ok(Type::NumSucc(a)),

(Type::Tup, Type::Tup) => Ok(Type::Tup),

(old, new) => Err(MatchErr::TypeMismatch(new, old)),
}
}
1 change: 0 additions & 1 deletion src/term/check/unbound_pats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl Pattern {
}
Pattern::Lst(args) => args.iter().for_each(|arg| check.push(arg)),
Pattern::Var(_) | Pattern::Num(_) => {}
Pattern::Err => unreachable!(),
}
}
unbounds
Expand Down
16 changes: 7 additions & 9 deletions src/term/display.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use super::{
net_to_term::ReadbackError, Book, Definition, MatchNum, Name, Op, Pattern, Rule, Tag, Term, Type,
};
use super::{net_to_term::ReadbackError, Book, Definition, Name, NumCtr, Op, Pattern, Rule, Tag, Term, Type};
use std::{fmt, ops::Deref};

/* Some aux structures for things that are not so simple to display */
Expand Down Expand Up @@ -113,7 +111,6 @@ impl fmt::Display for Pattern {
Pattern::Num(num) => write!(f, "{num}"),
Pattern::Tup(fst, snd) => write!(f, "({}, {})", fst, snd,),
Pattern::Lst(pats) => write!(f, "[{}]", DisplayJoin(|| pats, ", ")),
Pattern::Err => write!(f, "<Invalid>"),
}
}
}
Expand Down Expand Up @@ -141,13 +138,13 @@ impl fmt::Display for Book {
}
}

impl fmt::Display for MatchNum {
impl fmt::Display for NumCtr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MatchNum::Zero => write!(f, "0"),
MatchNum::Succ(None) => write!(f, "+"),
MatchNum::Succ(Some(None)) => write!(f, "+*"),
MatchNum::Succ(Some(Some(nam))) => write!(f, "+{nam}"),
NumCtr::Num(n) => write!(f, "{n}"),
NumCtr::Succ(n, None) => write!(f, "{n}+"),
NumCtr::Succ(n, Some(None)) => write!(f, "{n}+*"),
NumCtr::Succ(n, Some(Some(nam))) => write!(f, "{n}+{nam}"),
}
}
}
Expand Down Expand Up @@ -188,6 +185,7 @@ impl fmt::Display for Type {
Type::Any => write!(f, "any"),
Type::Tup => write!(f, "tup"),
Type::Num => write!(f, "num"),
Type::NumSucc(n) => write!(f, "{n}+"),
Type::Adt(nam) => write!(f, "{nam}"),
}
}
Expand Down
136 changes: 92 additions & 44 deletions src/term/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use self::{check::type_check::infer_type, parser::lexer::STRINGS};
use self::{check::type_check::infer_match_arg_type, parser::lexer::STRINGS};
use crate::{diagnostics::Info, term::builtins::*, ENTRY_POINT};
use indexmap::{IndexMap, IndexSet};
use interner::global::GlobalString;
use itertools::Itertools;
use std::{borrow::Cow, collections::HashMap, ops::Deref};
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
ops::Deref,
};

pub mod builtins;
pub mod check;
Expand Down Expand Up @@ -136,21 +140,19 @@ pub enum Term {
Err,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Pattern {
Var(Option<Name>),
Ctr(Name, Vec<Pattern>),
Num(MatchNum),
Num(NumCtr),
Tup(Box<Pattern>, Box<Pattern>),
Lst(Vec<Pattern>),
#[default]
Err,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MatchNum {
Zero,
Succ(Option<Option<Name>>),
pub enum NumCtr {
Num(u64),
Succ(u64, Option<Option<Name>>),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -185,9 +187,15 @@ pub enum Op {
/// Pattern types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
/// Variables/wildcards.
Any,
/// A native tuple.
Tup,
/// A sequence of arbitrary numbers ending in a variable.
Num,
/// A strictly incrementing sequence of numbers starting from 0, ending in a + ctr.
NumSucc(u64),
/// Adt constructors declared with the `data` syntax.
Adt(Name),
}

Expand Down Expand Up @@ -239,6 +247,7 @@ impl Tag {
}

impl Term {
/* Common construction patterns */
pub fn lam(nam: Option<Name>, bod: Term) -> Self {
Term::Lam { tag: Tag::Static, nam, bod: Box::new(bod) }
}
Expand Down Expand Up @@ -289,6 +298,30 @@ impl Term {
Term::Str { val: STRINGS.get(str) }
}

pub fn native_num_match(arg: Term, zero: Term, succ: Term) -> Term {
let zero = Rule { pats: vec![Pattern::Num(NumCtr::Num(0))], body: zero };
let succ = Rule { pats: vec![Pattern::Num(NumCtr::Succ(1, None))], body: succ };
Term::Mat { args: vec![arg], rules: vec![zero, succ] }
}

pub fn sub_num(arg: Term, val: u64) -> Term {
if val == 0 {
arg
} else {
Term::Opx { op: Op::SUB, fst: Box::new(arg), snd: Box::new(Term::Num { val }) }
}
}

pub fn add_num(arg: Term, val: u64) -> Term {
if val == 0 {
arg
} else {
Term::Opx { op: Op::ADD, fst: Box::new(arg), snd: Box::new(Term::Num { val }) }
}
}

/* Common checks and transformations */

/// Substitute the occurrences of a variable in a term with the given term.
/// Caution: can cause invalid shadowing of variables if used incorrectly.
/// Ex: Using subst to beta-reduce (@a @b a b) converting it into (@b b).
Expand Down Expand Up @@ -507,7 +540,7 @@ impl Term {
return false;
}
// The match is over a valid type
let Ok(typ) = infer_type(rules.iter().map(|r| &r.pats[0]), ctrs) else {
let Ok(typ) = infer_match_arg_type(rules, 0, ctrs) else {
return false;
};
// The match has one arm for each constructor, matching the constructors in adt declaration order
Expand All @@ -529,14 +562,32 @@ impl Term {
}
}
Type::Num => {
if rules.len() != 2 {
return false;
let mut nums = HashSet::new();
for rule in rules {
if let Pattern::Num(NumCtr::Num(n)) = &rule.pats[0] {
if nums.contains(n) {
return false;
}
nums.insert(*n);
}
}
if !matches!(rules[0].pats[0], Pattern::Num(MatchNum::Zero)) {
}
Type::NumSucc(n) => {
if rules.len() as u64 != n + 1 {
return false;
}
if !matches!(rules[1].pats[0], Pattern::Num(MatchNum::Succ(Some(_)))) {
return false;
for (i, _) in rules.iter().enumerate() {
if i as u64 == n {
let Pattern::Num(NumCtr::Succ(n_pat, Some(_))) = &rules[i].pats[0] else { return false };
if n != *n_pat {
return false;
}
} else {
let Pattern::Num(NumCtr::Num(i_pat)) = &rules[i].pats[0] else { return false };
if i as u64 != *i_pat {
return false;
}
}
}
}
Type::Adt(adt) => {
Expand Down Expand Up @@ -570,7 +621,7 @@ impl Pattern {
pub fn bind_or_eras(&self) -> impl DoubleEndedIterator<Item = &Option<Name>> {
self.iter().filter_map(|pat| match pat {
Pattern::Var(nam) => Some(nam),
Pattern::Num(MatchNum::Succ(nam)) => nam.as_ref(),
Pattern::Num(NumCtr::Succ(_, nam)) => nam.as_ref(),
_ => None,
})
}
Expand All @@ -585,11 +636,10 @@ impl Pattern {
go(fst, set);
go(snd, set);
}
Pattern::Num(MatchNum::Succ(Some(nam))) => {
Pattern::Num(NumCtr::Succ(_, Some(nam))) => {
set.push(nam);
}
Pattern::Num(_) => {}
Pattern::Err => unreachable!(),
}
}
let mut set = Vec::new();
Expand All @@ -614,7 +664,6 @@ impl Pattern {
Pattern::Num(_) => Box::new([].iter()),
Pattern::Tup(fst, snd) => Box::new([fst.as_ref(), snd.as_ref()].into_iter()),
Pattern::Lst(els) => Box::new(els.iter()),
Pattern::Err => unreachable!(),
}
}

Expand All @@ -634,24 +683,24 @@ impl Pattern {
match self {
Pattern::Var(_) => None,
Pattern::Ctr(nam, _) => Some(nam.clone()),
Pattern::Num(MatchNum::Zero) => Some(Name::new("0")),
Pattern::Num(MatchNum::Succ(_)) => Some(Name::new("+")),
Pattern::Num(NumCtr::Num(num)) => Some(Name::new(format!("{num}"))),
Pattern::Num(NumCtr::Succ(num, _)) => Some(Name::new(format!("{num}+"))),
Pattern::Tup(_, _) => Some(Name::new("(,)")),
Pattern::Lst(_) => todo!(),
Pattern::Err => unreachable!(),
}
}

pub fn is_wildcard(&self) -> bool {
matches!(self, Pattern::Var(_))
}

pub fn is_detached_num_match(&self) -> bool {
if let Pattern::Num(num) = self {
match num {
MatchNum::Zero => true,
MatchNum::Succ(None) => true,
MatchNum::Succ(Some(_)) => false,
pub fn is_native_num_match(&self) -> bool {
if let Pattern::Num(ctr) = self {
match ctr {
NumCtr::Num(0) => true,
NumCtr::Num(_) => false,
NumCtr::Succ(1, None) => true,
NumCtr::Succ(_, _) => false,
}
} else {
false
Expand All @@ -667,7 +716,6 @@ impl Pattern {
Pattern::Tup(fst, snd) => {
matches!(fst.as_ref(), Pattern::Var(_)) && matches!(snd.as_ref(), Pattern::Var(_))
}
Pattern::Err => unreachable!(),
}
}

Expand All @@ -679,9 +727,9 @@ impl Pattern {
Type::Adt(adt_nam.clone())
}
Pattern::Tup(..) => Type::Tup,
Pattern::Num(..) => Type::Num,
Pattern::Num(NumCtr::Num(_)) => Type::Num,
Pattern::Num(NumCtr::Succ(n, _)) => Type::NumSucc(*n),
Pattern::Lst(..) => Type::Adt(builtins::LIST.into()),
Pattern::Err => unreachable!(),
}
}

Expand All @@ -691,31 +739,28 @@ impl Pattern {
Pattern::Ctr(ctr, args) => {
Term::call(Term::Ref { nam: ctr.clone() }, args.iter().map(|arg| arg.to_term()))
}
Pattern::Num(MatchNum::Zero) => Term::Num { val: 0 },
Pattern::Num(NumCtr::Num(val)) => Term::Num { val: *val },
// Succ constructor with no variable is not a valid term, only a compiler intermediate for a MAT inet node.
Pattern::Num(MatchNum::Succ(None)) => unreachable!(),
Pattern::Num(MatchNum::Succ(Some(Some(nam)))) => Term::Opx {
op: Op::ADD,
fst: Box::new(Term::Var { nam: nam.clone() }),
snd: Box::new(Term::Num { val: 1 }),
},
Pattern::Num(MatchNum::Succ(Some(None))) => Term::Era,
Pattern::Num(NumCtr::Succ(_, None)) => unreachable!(),
Pattern::Num(NumCtr::Succ(val, Some(Some(nam)))) => Term::add_num(Term::Var { nam: nam.clone() }, *val),
Pattern::Num(NumCtr::Succ(_, Some(None))) => Term::Era,
Pattern::Tup(fst, snd) => Term::Tup { fst: Box::new(fst.to_term()), snd: Box::new(snd.to_term()) },
Pattern::Lst(_) => {
let mut p = self.clone();
p.encode_builtins();
p.to_term()
}
Pattern::Err => unreachable!(),
}
}

/// True if both patterns are equal (match the same expressions) without considering nested patterns.
pub fn simple_equals(&self, other: &Pattern) -> bool {
match (self, other) {
(Pattern::Ctr(a, _), Pattern::Ctr(b, _)) if a == b => true,
(Pattern::Num(MatchNum::Zero), Pattern::Num(MatchNum::Zero)) => true,
(Pattern::Num(MatchNum::Succ(_)), Pattern::Num(MatchNum::Succ(_))) => true,
(Pattern::Num(NumCtr::Num(a)), Pattern::Num(NumCtr::Num(b))) if a == b => true,
(Pattern::Num(NumCtr::Num(_)), Pattern::Num(NumCtr::Num(_))) => false,
(Pattern::Num(NumCtr::Succ(a, _)), Pattern::Num(NumCtr::Succ(b, _))) if a == b => true,
(Pattern::Num(NumCtr::Succ(_, _)), Pattern::Num(NumCtr::Succ(_, _))) => false,
(Pattern::Tup(_, _), Pattern::Tup(_, _)) => true,
(Pattern::Lst(_), Pattern::Lst(_)) => true,
(Pattern::Var(_), Pattern::Var(_)) => true,
Expand Down Expand Up @@ -774,9 +819,12 @@ impl Type {
Box::new(Pattern::Var(Some("%fst".into()))),
Box::new(Pattern::Var(Some("%snd".into()))),
)],
Type::Num => {
vec![Pattern::Num(MatchNum::Zero), Pattern::Num(MatchNum::Succ(Some(Some("%pred".into()))))]
Type::NumSucc(n) => {
let mut ctrs = (0 .. *n).map(|n| Pattern::Num(NumCtr::Num(n))).collect::<Vec<_>>();
ctrs.push(Pattern::Num(NumCtr::Succ(*n, Some(Some("%pred".into())))));
ctrs
}
Type::Num => unreachable!(),
Type::Adt(adt) => {
// TODO: Should return just a ref to ctrs and not clone.
adts[adt]
Expand Down
Loading
Loading