Skip to content

Commit

Permalink
Implement limited type inference for numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Aug 14, 2024
1 parent 01a7fb4 commit c0fc97a
Show file tree
Hide file tree
Showing 10 changed files with 524 additions and 272 deletions.
4 changes: 2 additions & 2 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ impl<T> std::fmt::Display for Pattern<T> {
PatternEnum::UnsignedInclusiveRange(min, max, suffix) => {
if min == max {
f.write_fmt(format_args!("{min}{suffix}"))
} else if *min == 0 && *max == suffix.max() {
} else if *min == 0 && Some(*max) == suffix.max() {
f.write_str("_")
} else {
f.write_fmt(format_args!("{min}{suffix}..={max}{suffix}"))
Expand All @@ -494,7 +494,7 @@ impl<T> std::fmt::Display for Pattern<T> {
PatternEnum::SignedInclusiveRange(min, max, suffix) => {
if min == max {
f.write_fmt(format_args!("{min}{suffix}"))
} else if *min == suffix.min() && *max == suffix.max() {
} else if Some(*min) == suffix.min() && Some(*max) == suffix.max() {
f.write_str("_")
} else {
f.write_fmt(format_args!("{min}{suffix}..={max}{suffix}"))
Expand Down
318 changes: 269 additions & 49 deletions src/check.rs

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,8 @@ impl Type {
Type::Unsigned(UnsignedNumType::U16) | Type::Signed(SignedNumType::I16) => 16,
Type::Unsigned(UnsignedNumType::U32) | Type::Signed(SignedNumType::I32) => 32,
Type::Unsigned(UnsignedNumType::U64) | Type::Signed(SignedNumType::I64) => 64,
Type::Unsigned(UnsignedNumType::Unspecified)
| Type::Signed(SignedNumType::Unspecified) => 32,
Type::Array(elem, size) => elem.size_in_bits_for_defs(prg, const_sizes) * size,
Type::ArrayConst(elem, size) => {
elem.size_in_bits_for_defs(prg, const_sizes) * const_sizes.get(size).unwrap()
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
//! assert_eq!(u32::try_from(output).map_err(|e| e.prettify(&code)).unwrap(), 2 + 10 + 100);
//!
//! // Or we can run the compiled circuit in an MPC engine, simulated using `prg.circuit.eval()`:
//! let x = prg.parse_arg(0, "2u32").unwrap().as_bits();
//! let y = prg.parse_arg(1, "10u32").unwrap().as_bits();
//! let z = prg.parse_arg(2, "100u32").unwrap().as_bits();
//! let x = prg.parse_arg(0, "2").unwrap().as_bits();
//! let y = prg.parse_arg(1, "10").unwrap().as_bits();
//! let z = prg.parse_arg(2, "100").unwrap().as_bits();
//! let output = prg.circuit.eval(&[x, y, z]); // use your own MPC engine here instead
//! let result = prg.parse_output(&output).unwrap();
//! assert_eq!("112u32", result.to_string());
//! assert_eq!("112", result.to_string());
//!
//! // Input arguments can also be constructed directly as literals:
//! let x = prg.literal_arg(0, Literal::NumUnsigned(2, U32)).unwrap().as_bits();
Expand Down
6 changes: 3 additions & 3 deletions src/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ impl Display for Literal {
match self {
Literal::True => write!(f, "true"),
Literal::False => write!(f, "false"),
Literal::NumUnsigned(n, ty) => write!(f, "{n}{ty}"),
Literal::NumSigned(n, ty) => {
write!(f, "{n}{ty}")
Literal::NumUnsigned(n, _) => write!(f, "{n}"),
Literal::NumSigned(n, _) => {
write!(f, "{n}")
}
Literal::ArrayRepeat(elem, size) => write!(f, "[{elem}; {size}]"),
Literal::Array(elems) => {
Expand Down
33 changes: 24 additions & 9 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ impl Parser {
TokenEnum::ShlAssign => Some(Op::ShiftLeft),
_ => None,
};
println!("next: {next:?}");
if let Some(op) = op {
self.advance();
let value = self.parse_expr()?;
Expand Down Expand Up @@ -1172,14 +1171,14 @@ impl Parser {
};
while self.peek(&TokenEnum::LeftBracket) || self.peek(&TokenEnum::Dot) {
if self.next_matches(&TokenEnum::LeftBracket).is_some() {
if let Some(Token(TokenEnum::ConstantIndexOrSize(i), meta)) = self.tokens.peek() {
if let Some(Token(TokenEnum::UnsignedNum(i, UnsignedNumType::Unspecified), meta)) =
self.tokens.peek()
{
let i = *i;
let meta = *meta;
self.advance();
let index = Expr::untyped(
ExprEnum::NumUnsigned(i as u64, UnsignedNumType::Usize),
meta,
);
let index =
Expr::untyped(ExprEnum::NumUnsigned(i, UnsignedNumType::Usize), meta);
let end = self.expect(&TokenEnum::RightBracket)?;
let meta = join_meta(expr.meta, end);
expr =
Expand All @@ -1195,7 +1194,11 @@ impl Parser {
let peeked = self.tokens.peek();
if let Some(Token(TokenEnum::Identifier(_), _)) = peeked {
expr = self.parse_method_call_or_struct_access(expr)?;
} else if let Some(Token(TokenEnum::ConstantIndexOrSize(i), meta_index)) = peeked {
} else if let Some(Token(
TokenEnum::UnsignedNum(i, UnsignedNumType::Unspecified),
meta_index,
)) = peeked
{
let i = *i;
let meta_index = *meta_index;
self.advance();
Expand Down Expand Up @@ -1374,7 +1377,13 @@ impl Parser {
if self.peek(&TokenEnum::Semicolon) {
self.expect(&TokenEnum::Semicolon)?;
match self.tokens.peek().cloned() {
Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) => {
Some(Token(
TokenEnum::UnsignedNum(
n,
UnsignedNumType::Unspecified | UnsignedNumType::Usize,
),
_,
)) => {
self.advance();
let meta_end = self.expect(&TokenEnum::RightBracket)?;
let meta = join_meta(meta, meta_end);
Expand Down Expand Up @@ -1446,7 +1455,13 @@ impl Parser {
let (ty, _) = self.parse_type()?;
self.expect(&TokenEnum::Semicolon)?;
match self.tokens.peek().cloned() {
Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) => {
Some(Token(
TokenEnum::UnsignedNum(
n,
UnsignedNumType::Unspecified | UnsignedNumType::Usize,
),
_,
)) => {
self.advance();
let size = n as usize;
let meta_end = self.expect(&TokenEnum::RightBracket)?;
Expand Down
70 changes: 34 additions & 36 deletions src/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl<'a> Scanner<'a> {
"i64" if (i64::MIN..=i64::MAX).contains(&n) => {
SignedNumType::I64
}
"" => SignedNumType::Unspecified,
_ => {
self.push_error(ScanErrorEnum::InvalidUnsignedNum);
SignedNumType::I64
Expand All @@ -268,42 +269,39 @@ impl<'a> Scanner<'a> {
while let Some(char) = self.next_matches_alphanumeric() {
literal_suffix.push(char);
}
if literal_suffix.is_empty() && n <= u32::MAX as u64 {
self.push_token(TokenEnum::ConstantIndexOrSize(n as u32));
} else {
let token = match literal_suffix.as_str() {
"i8" if n <= i8::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I8)
}
"i16" if n <= i16::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I16)
}
"i32" if n <= i32::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I32)
}
"i64" if n <= i64::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I64)
}
"usize" if n <= usize::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::Usize)
}
"u8" if n <= u8::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::U8)
}
"u16" if n <= u16::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::U16)
}
"u32" if n <= u32::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::U32)
}
"u64" => TokenEnum::UnsignedNum(n, UnsignedNumType::U64),
_ => {
self.push_error(ScanErrorEnum::InvalidUnsignedNum);
TokenEnum::UnsignedNum(n, UnsignedNumType::U64)
}
};
self.push_token(token);
}
let token = match literal_suffix.as_str() {
"i8" if n <= i8::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I8)
}
"i16" if n <= i16::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I16)
}
"i32" if n <= i32::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I32)
}
"i64" if n <= i64::MAX as u64 => {
TokenEnum::SignedNum(n as i64, SignedNumType::I64)
}
"usize" if n <= usize::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::Usize)
}
"u8" if n <= u8::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::U8)
}
"u16" if n <= u16::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::U16)
}
"u32" if n <= u32::MAX as u64 => {
TokenEnum::UnsignedNum(n, UnsignedNumType::U32)
}
"u64" => TokenEnum::UnsignedNum(n, UnsignedNumType::U64),
"" => TokenEnum::UnsignedNum(n, UnsignedNumType::Unspecified),
_ => {
self.push_error(ScanErrorEnum::InvalidUnsignedNum);
TokenEnum::UnsignedNum(n, UnsignedNumType::U64)
}
};
self.push_token(token);
} else {
self.push_error(ScanErrorEnum::InvalidUnsignedNum);
}
Expand Down
44 changes: 25 additions & 19 deletions src/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ pub struct Token(pub TokenEnum, pub MetaInfo);
pub enum TokenEnum {
/// Identifier of alphanumeric chars.
Identifier(String),
/// Constant unsigned number used to index arrays / tuples.
ConstantIndexOrSize(u32),
/// Unsigned number.
UnsignedNum(u64, UnsignedNumType),
/// Signed number.
Expand Down Expand Up @@ -140,7 +138,6 @@ impl std::fmt::Display for TokenEnum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TokenEnum::Identifier(s) => f.write_str(s),
TokenEnum::ConstantIndexOrSize(num) => f.write_fmt(format_args!("{num}")),
TokenEnum::UnsignedNum(num, suffix) => f.write_fmt(format_args!("{num}{suffix}")),
TokenEnum::SignedNum(num, suffix) => f.write_fmt(format_args!("{num}{suffix}")),
TokenEnum::KeywordConst => f.write_str("const"),
Expand Down Expand Up @@ -219,17 +216,20 @@ pub enum UnsignedNumType {
U32,
/// 64-bit unsigned integer type.
U64,
/// No type suffix has been specified, could be any from i8 to i64.
Unspecified,
}

impl UnsignedNumType {
/// Returns the max value representable by this type.
pub fn max(&self) -> u64 {
pub fn max(&self) -> Option<u64> {
match self {
UnsignedNumType::Usize => u32::MAX as u64,
UnsignedNumType::U8 => u8::MAX as u64,
UnsignedNumType::U16 => u16::MAX as u64,
UnsignedNumType::U32 => u32::MAX as u64,
UnsignedNumType::U64 => u64::MAX,
UnsignedNumType::Usize => Some(u32::MAX as u64),
UnsignedNumType::U8 => Some(u8::MAX as u64),
UnsignedNumType::U16 => Some(u16::MAX as u64),
UnsignedNumType::U32 => Some(u32::MAX as u64),
UnsignedNumType::U64 => Some(u64::MAX),
UnsignedNumType::Unspecified => None,
}
}
}
Expand All @@ -242,6 +242,7 @@ impl std::fmt::Display for UnsignedNumType {
UnsignedNumType::U16 => "u16",
UnsignedNumType::U32 => "u32",
UnsignedNumType::U64 => "u64",
UnsignedNumType::Unspecified => "unspecified unsigned int",
})
}
}
Expand All @@ -258,26 +259,30 @@ pub enum SignedNumType {
I32,
/// 64-bit signed integer type.
I64,
/// No type suffix has been specified, could be any from i8 to i64.
Unspecified,
}

impl SignedNumType {
/// Returns the minimum value representable by this type.
pub fn min(&self) -> i64 {
pub fn min(&self) -> Option<i64> {
match self {
SignedNumType::I8 => i8::MIN as i64,
SignedNumType::I16 => i16::MIN as i64,
SignedNumType::I32 => i32::MIN as i64,
SignedNumType::I64 => i64::MIN,
SignedNumType::I8 => Some(i8::MIN as i64),
SignedNumType::I16 => Some(i16::MIN as i64),
SignedNumType::I32 => Some(i32::MIN as i64),
SignedNumType::I64 => Some(i64::MIN),
SignedNumType::Unspecified => None,
}
}

/// Returns the maximum value representable by this type.
pub fn max(&self) -> i64 {
pub fn max(&self) -> Option<i64> {
match self {
SignedNumType::I8 => i8::MAX as i64,
SignedNumType::I16 => i16::MAX as i64,
SignedNumType::I32 => i32::MAX as i64,
SignedNumType::I64 => i64::MAX,
SignedNumType::I8 => Some(i8::MAX as i64),
SignedNumType::I16 => Some(i16::MAX as i64),
SignedNumType::I32 => Some(i32::MAX as i64),
SignedNumType::I64 => Some(i64::MAX),
SignedNumType::Unspecified => None,
}
}
}
Expand All @@ -289,6 +294,7 @@ impl std::fmt::Display for SignedNumType {
SignedNumType::I16 => "i16",
SignedNumType::I32 => "i32",
SignedNumType::I64 => "i64",
SignedNumType::Unspecified => "unspecified signed int",
})
}
}
Expand Down
Loading

0 comments on commit c0fc97a

Please sign in to comment.