Skip to content

Commit

Permalink
numeric casting operations (#361)
Browse files Browse the repository at this point in the history
Adds support for numeric casting operations, `[u24]`, `[i24]`, `[f24]`.

Fixes #322.

Casts are always defined and, in spirit, follow Rust's numeric casting semantics:
- `i24 <-> u24` is just reinterpretation of bits.
- `f24 -> i24` or `f24 -> u24` casts to the "closest" integer representing this float, saturating if out of range and `0` if `NaN`.
- `i24 -> f24` or `u24 -> f24` casts to the "closest" float representing this integer.
  • Loading branch information
enricozb authored Jun 3, 2024
2 parents 166a3cb + f080a3a commit f158106
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 8 deletions.
19 changes: 19 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ impl<'i> CoreParser<'i> {
pub fn parse_numb_sym(&mut self) -> Result<Numb, String> {
self.consume("[")?;

// numeric casts
if let Some(cast) = match () {
_ if self.try_consume("u24") => Some(hvm::TY_U24),
_ if self.try_consume("i24") => Some(hvm::TY_I24),
_ if self.try_consume("f24") => Some(hvm::TY_F24),
_ => None
} {
// Casts can't be partially applied, so nothing should follow.
self.consume("]")?;

return Ok(Numb(hvm::Numb::new_sym(cast).0));
}

// Parses the symbol
let op = hvm::Numb::new_sym(match () {
// numeric operations
_ if self.try_consume("+") => hvm::OP_ADD,
_ if self.try_consume("-") => hvm::OP_SUB,
_ if self.try_consume(":-") => hvm::FP_SUB,
Expand Down Expand Up @@ -224,6 +238,11 @@ impl Numb {
let numb = hvm::Numb(self.0);
match numb.get_typ() {
hvm::TY_SYM => match numb.get_sym() as hvm::Tag {
// casts
hvm::TY_U24 => "[u24]".to_string(),
hvm::TY_I24 => "[i24]".to_string(),
hvm::TY_F24 => "[f24]".to_string(),
// operations
hvm::OP_ADD => "[+]".to_string(),
hvm::OP_SUB => "[-]".to_string(),
hvm::FP_SUB => "[:-]".to_string(),
Expand Down
78 changes: 76 additions & 2 deletions src/hvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ typedef uint16_t u16;
typedef int32_t i32;
typedef uint32_t u32;
typedef uint64_t u64;
typedef float f32;
typedef double f64;

typedef _Atomic(u8) a8;
typedef _Atomic(u16) a16;
Expand Down Expand Up @@ -75,6 +77,10 @@ typedef u32 Numb; // Numb ::= 29-bit (rounded up to u32)
#define SWIT 0x7

// Numbers
// Saturation bounds for the f24 -> u24 and f24 -> i24 casts, kept as f32 so
// they can be handed straight to clamp().
// Largest u24 value: 2^24 - 1 (the cast binds tighter than the subtraction).
static const f32 U24_MAX = (f32) (1 << 24) - 1;
// Smallest u24 value.
static const f32 U24_MIN = 0.0;
// Largest i24 value: 2^23 - 1.
static const f32 I24_MAX = (f32) (1 << 23) - 1;
// Smallest i24 value: all-ones shifted left by 23, converted to a signed i32
// (-2^23 on two's-complement targets) and then to f32.
static const f32 I24_MIN = (f32) (i32) ((-1u) << 23);
#define TY_SYM 0x00
#define TY_U24 0x01
#define TY_I24 0x02
Expand Down Expand Up @@ -278,10 +284,15 @@ static inline void swap(Port *a, Port *b) {
Port x = *a; *a = *b; *b = x;
}

u32 min(u32 a, u32 b) {
// Returns the smaller of two u32 values.
// Declared `static inline`, like the other helpers in this file: a plain C99
// `inline` definition does not emit an external symbol, so any call the
// compiler decides not to inline (e.g. at -O0) would fail to link.
static inline u32 min(u32 a, u32 b) {
  return (a < b) ? a : b;
}

// Clamps x into the closed interval [lo, hi].
// NaN is returned unchanged, because both comparisons are false for NaN;
// callers that need a defined result for NaN must test for it first.
// Declared `static inline`, like the other helpers in this file: a plain C99
// `inline` definition does not emit an external symbol, so any call the
// compiler decides not to inline (e.g. at -O0) would fail to link.
// The bounds are named lo/hi so they do not shadow the min() helper.
static inline f32 clamp(f32 x, f32 lo, f32 hi) {
  const f32 t = x < lo ? lo : x;
  return (t > hi) ? hi : t;
}

// A simple spin-wait barrier using atomic operations
a64 a_reached = 0; // number of threads that reached the current barrier
a64 a_barrier = 0; // number of barriers passed during this program
Expand Down Expand Up @@ -429,18 +440,76 @@ static inline Tag get_typ(Numb word) {
return word & 0x1F;
}

// True when `word` is tagged as a number, i.e. its type tag lies between
// TY_U24 and TY_F24 inclusive (assumes the numeric tags are consecutive).
static inline bool is_num(Numb word) {
  const Tag typ = get_typ(word);
  return TY_U24 <= typ && typ <= TY_F24;
}

// True when `word` is a cast symbol: a TY_SYM word whose symbol value lies
// between TY_U24 and TY_F24 inclusive.
static inline bool is_cast(Numb word) {
  if (get_typ(word) != TY_SYM) {
    return false;
  }
  return get_sym(word) >= TY_U24 && get_sym(word) <= TY_F24;
}

// Partial application: keeps every bit of `b` above the low 5-bit tag field
// and ORs the symbol of `a` into the result.
static inline Numb partial(Numb a, Numb b) {
  const Numb payload = b & ~0x1F;
  return payload | get_sym(a);
}

// Converts the number `b` to the type named by the cast symbol `a`.
// The rules loosely mirror Rust's numeric `as` conversions:
//   - u24 <-> i24: the bits are kept and only reinterpreted;
//   - f24 -> u24 / i24: the value is saturated to the target range and then
//     converted to an integer; NaN becomes 0;
//   - u24 / i24 -> f24: the integer is converted with an f32 cast;
//   - same type: `b` is returned unchanged.
// Any other combination yields the u24 value 0.
static inline Numb cast(Numb a, Numb b) {
  const u32 dst = get_sym(a);
  const u32 src = get_typ(b);

  switch (dst) {
    case TY_U24:
      switch (src) {
        case TY_U24:
          return b;
        case TY_I24: {
          // reinterpret the signed bits as unsigned
          i32 bits = get_i24(b);
          return new_u24(*(u32*) &bits);
        }
        case TY_F24: {
          f32 val = get_f24(b);
          // NaN never compares, so clamp() cannot handle it
          if (isnan(val)) return new_u24(0);
          return new_u24((u32) clamp(val, U24_MIN, U24_MAX));
        }
      }
      break;

    case TY_I24:
      switch (src) {
        case TY_U24: {
          // reinterpret the unsigned bits as signed
          u32 bits = get_u24(b);
          return new_i24(*(i32*) &bits);
        }
        case TY_I24:
          return b;
        case TY_F24: {
          f32 val = get_f24(b);
          // NaN never compares, so clamp() cannot handle it
          if (isnan(val)) return new_i24(0);
          return new_i24((i32) clamp(val, I24_MIN, I24_MAX));
        }
      }
      break;

    case TY_F24:
      switch (src) {
        case TY_U24:
          return new_f24((f32) get_u24(b));
        case TY_I24:
          return new_f24((f32) get_i24(b));
        case TY_F24:
          return b;
      }
      break;
  }

  // not a recognised (target, source) pair
  return new_u24(0);
}

// Operate function
static inline Numb operate(Numb a, Numb b) {
Tag at = get_typ(a);
Tag bt = get_typ(b);
if (at == TY_SYM && bt == TY_SYM) {
return new_u24(0);
}
if (is_cast(a) && is_num(b)) {
return cast(a, b);
}
if (is_cast(b) && is_num(a)) {
return cast(b, a);
}
if (at == TY_SYM && bt != TY_SYM) {
return partial(a, b);
}
Expand Down Expand Up @@ -1916,6 +1985,11 @@ void pretty_print_numb(Numb word) {
switch (get_typ(word)) {
case TY_SYM: {
switch (get_sym(word)) {
// types
case TY_U24: printf("[u24]"); break;
case TY_I24: printf("[i24]"); break;
case TY_F24: printf("[f24]"); break;
// operations
case OP_ADD: printf("[+]"); break;
case OP_SUB: printf("[-]"); break;
case FP_SUB: printf("[:-]"); break;
Expand Down Expand Up @@ -1957,7 +2031,7 @@ void pretty_print_numb(Numb word) {
} else if (isnan(get_f24(word))) {
printf("+NaN");
} else {
printf("%f", get_f24(word));
printf("%.7e", get_f24(word));
}
break;
}
Expand Down
70 changes: 69 additions & 1 deletion src/hvm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ __global__ void print_heatmap(GNet* gnet, u32 turn);
// Utils
// -----

// Restricts x to the closed range [lo, hi]: first raised to lo if below it,
// then lowered to hi if above it. NaN passes through unchanged because both
// comparisons are false for NaN.
__device__ __host__ f32 clamp(f32 x, f32 lo, f32 hi) {
  f32 bounded = x;
  if (bounded < lo) {
    bounded = lo;
  }
  if (bounded > hi) {
    bounded = hi;
  }
  return bounded;
}

// TODO: write a time64() function that returns the time as fast as possible as a u64
static inline u64 time64() {
struct timespec ts;
Expand Down Expand Up @@ -541,6 +546,58 @@ __device__ __host__ inline Tag get_typ(Numb word) {
return word & 0x1F;
}

// True when `word` is tagged as a number, i.e. its type tag lies between
// TY_U24 and TY_F24 inclusive (assumes the numeric tags are consecutive).
__device__ __host__ inline bool is_num(Numb word) {
  const Tag typ = get_typ(word);
  return TY_U24 <= typ && typ <= TY_F24;
}

// True when `word` is a cast symbol: a TY_SYM word whose symbol value lies
// between TY_U24 and TY_F24 inclusive.
__device__ __host__ inline bool is_cast(Numb word) {
  if (get_typ(word) != TY_SYM) {
    return false;
  }
  return get_sym(word) >= TY_U24 && get_sym(word) <= TY_F24;
}

// Converts the number `b` to the type named by the cast symbol `a`.
// The rules loosely mirror Rust's numeric `as` conversions:
//   - u24 <-> i24: the bits are kept and only reinterpreted;
//   - f24 -> u24 / i24: the value is saturated to the target range and then
//     converted to an integer; NaN becomes 0;
//   - u24 / i24 -> f24: the integer is converted with an f32 cast;
//   - same type: `b` is returned unchanged.
// Any other combination yields the u24 value 0.
__device__ __host__ inline Numb cast(Numb a, Numb b) {
  const u32 dst = get_sym(a);
  const u32 src = get_typ(b);

  switch (dst) {
    case TY_U24:
      switch (src) {
        case TY_U24:
          return b;
        case TY_I24: {
          // reinterpret the signed bits as unsigned
          i32 bits = get_i24(b);
          return new_u24(*(u32*) &bits);
        }
        case TY_F24: {
          f32 val = get_f24(b);
          // NaN never compares, so clamp() cannot handle it
          if (isnan(val)) return new_u24(0);
          // bounds are the u24 range [0, 2^24 - 1]
          return new_u24((u32) clamp(val, 0.0, 16777215));
        }
      }
      break;

    case TY_I24:
      switch (src) {
        case TY_U24: {
          // reinterpret the unsigned bits as signed
          u32 bits = get_u24(b);
          return new_i24(*(i32*) &bits);
        }
        case TY_I24:
          return b;
        case TY_F24: {
          f32 val = get_f24(b);
          // NaN never compares, so clamp() cannot handle it
          if (isnan(val)) return new_i24(0);
          // bounds are the i24 range [-2^23, 2^23 - 1]
          return new_i24((i32) clamp(val, -8388608.0, 8388607.0));
        }
      }
      break;

    case TY_F24:
      switch (src) {
        case TY_U24:
          return new_f24((f32) get_u24(b));
        case TY_I24:
          return new_f24((f32) get_i24(b));
        case TY_F24:
          return b;
      }
      break;
  }

  // not a recognised (target, source) pair
  return new_u24(0);
}

// Partial application
__device__ __host__ inline Numb partial(Numb a, Numb b) {
return (b & ~0x1F) | get_sym(a);
Expand All @@ -553,6 +610,12 @@ __device__ __host__ inline Numb operate(Numb a, Numb b) {
if (at == TY_SYM && bt == TY_SYM) {
return new_u24(0);
}
if (is_cast(a) && is_num(b)) {
return cast(a, b);
}
if (is_cast(b) && is_num(a)) {
return cast(b, a);
}
if (at == TY_SYM && bt != TY_SYM) {
return partial(a, b);
}
Expand Down Expand Up @@ -2403,6 +2466,11 @@ __device__ void pretty_print_numb(Numb word) {
switch (get_typ(word)) {
case TY_SYM: {
switch (get_sym(word)) {
// types
case TY_U24: printf("[u24]"); break;
case TY_I24: printf("[i24]"); break;
case TY_F24: printf("[f24]"); break;
// operations
case OP_ADD: printf("[+]"); break;
case OP_SUB: printf("[-]"); break;
case FP_SUB: printf("[:-]"); break;
Expand Down Expand Up @@ -2444,7 +2512,7 @@ __device__ void pretty_print_numb(Numb word) {
} else if (isnan(get_f24(word))) {
printf("+NaN");
} else {
printf("%f", get_f24(word));
printf("%.7e", get_f24(word));
}
break;
}
Expand Down
49 changes: 45 additions & 4 deletions src/hvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ pub struct APair(pub AtomicU64);

// Number
pub struct Numb(pub Val);
// Inclusive bounds of the 24-bit integer types; Numb::cast clamps float
// conversions to these ranges.
// Largest u24 value: 2^24 - 1.
const U24_MAX : u32 = (1 << 24) - 1;
// Smallest u24 value.
const U24_MIN : u32 = 0;
// Largest i24 value: 2^23 - 1.
const I24_MAX : i32 = (1 << 23) - 1;
// Smallest i24 value: -2^23.
const I24_MIN : i32 = (-1) << 23;

// Tags
pub const VAR : Tag = 0x0; // variable
Expand Down Expand Up @@ -266,16 +270,47 @@ impl Numb {
}

// Gets the numeric type: the low 5 bits (mask 0x1F) of the word, as a Tag.
// NOTE(review): the `return ...;` statement and the bare tail expression below
// look like the removed and added sides of the same diff hunk; presumably only
// one of them exists in the real file. Confirm against the repository.

pub fn get_typ(&self) -> Tag {
return (self.0 & 0x1F) as Tag;
(self.0 & 0x1F) as Tag
}

// True when this word is tagged as a number, i.e. its type tag lies in the
// inclusive range TY_U24..=TY_F24.
pub fn is_num(&self) -> bool {
  (TY_U24..=TY_F24).contains(&self.get_typ())
}

// True when this word is a cast symbol: a TY_SYM word whose symbol lies in
// the inclusive range TY_U24..=TY_F24.
pub fn is_cast(&self) -> bool {
  if self.get_typ() != TY_SYM {
    return false;
  }
  (TY_U24..=TY_F24).contains(&self.get_sym())
}

// Partial application: clears the low 5 tag bits of `b` and ORs in the symbol
// of `a` (widened to u32).
// NOTE(review): the `return ...;` statement and the bare tail expression below
// look like the removed and added sides of the same diff hunk; presumably only
// one of them exists in the real file. Confirm against the repository.
pub fn partial(a: Self, b: Self) -> Self {
return Numb((b.0 & !0x1F) | a.get_sym() as u32);
Numb((b.0 & !0x1F) | a.get_sym() as u32)
}

// Converts the number `b` to the type named by the cast symbol `a`.
// The rules loosely mirror Rust's own numeric `as` conversions:
//   - u24 <-> i24: the bits are kept and only reinterpreted;
//   - f24 -> u24 / i24: converted with `as` (which saturates and maps NaN to 0)
//     and then clamped to the 24-bit range;
//   - u24 / i24 -> f24: the integer is converted with an `as f32` cast;
//   - same type: `b` is returned unchanged.
// Any other combination yields the u24 value 0.
pub fn cast(a: Self, b: Self) -> Self {
  match a.get_sym() {
    TY_U24 => match b.get_typ() {
      TY_U24 => b,
      TY_I24 => Self::new_u24(b.get_i24() as u32),
      TY_F24 => {
        let whole = b.get_f24() as u32;
        Self::new_u24(whole.clamp(U24_MIN, U24_MAX))
      }
      _ => Self::new_u24(0),
    },

    TY_I24 => match b.get_typ() {
      TY_U24 => Self::new_i24(b.get_u24() as i32),
      TY_I24 => b,
      TY_F24 => {
        let whole = b.get_f24() as i32;
        Self::new_i24(whole.clamp(I24_MIN, I24_MAX))
      }
      _ => Self::new_u24(0),
    },

    TY_F24 => match b.get_typ() {
      TY_U24 => Self::new_f24(b.get_u24() as f32),
      TY_I24 => Self::new_f24(b.get_i24() as f32),
      TY_F24 => b,
      _ => Self::new_u24(0),
    },

    // invalid cast
    _ => Self::new_u24(0),
  }
}

pub fn operate(a: Self, b: Self) -> Self {
Expand All @@ -285,6 +320,12 @@ impl Numb {
if at == TY_SYM && bt == TY_SYM {
return Numb::new_u24(0);
}
if a.is_cast() && b.is_num() {
return Numb::cast(a, b);
}
if b.is_cast() && a.is_num() {
return Numb::cast(b, a);
}
if at == TY_SYM && bt != TY_SYM {
return Numb::partial(a, b);
}
Expand Down
45 changes: 45 additions & 0 deletions tests/programs/numeric-casts.hvm
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Exercises the numeric cast symbols [u24], [i24] and [f24].
// Each @t* definition applies one cast (`value ~ $([type] n)`) and links to the
// next definition in the chain, ending at @t. The trailing comment on each
// line is presumably the expected value of `n` for that step.
@main = x & @tu0 ~ (* x)

// casting to u24
@tu0 = (* {n x}) & @tu1 ~ (* x) & 0 ~ $([u24] n) // 0
@tu1 = (* {n x}) & @tu2 ~ (* x) & 1234 ~ $([u24] n) // 1234
@tu2 = (* {n x}) & @tu3 ~ (* x) & +4321 ~ $([u24] n) // 4321
@tu3 = (* {n x}) & @tu4 ~ (* x) & -5678 ~ $([u24] n) // 16771538 (reinterprets bits)
@tu4 = (* {n x}) & @tu5 ~ (* x) & 2.8 ~ $([u24] n) // 2 (rounds to zero)
@tu5 = (* {n x}) & @tu6 ~ (* x) & -12.5 ~ $([u24] n) // 0 (saturates)
@tu6 = (* {n x}) & @tu7 ~ (* x) & 16777216.0 ~ $([u24] n) // 16777215 (saturates)
@tu7 = (* {n x}) & @tu8 ~ (* x) & +inf ~ $([u24] n) // 16777215 (saturates)
@tu8 = (* {n x}) & @tu9 ~ (* x) & -inf ~ $([u24] n) // 0 (saturates)
@tu9 = (* {n x}) & @ti0 ~ (* x) & +NaN ~ $([u24] n) // 0

// casting to i24
@ti0 = (* {n x}) & @ti1 ~ (* x) & 0 ~ $([i24] n) // +0
@ti1 = (* {n x}) & @ti2 ~ (* x) & 1234 ~ $([i24] n) // +1234
@ti2 = (* {n x}) & @ti3 ~ (* x) & +4321 ~ $([i24] n) // +4321
@ti3 = (* {n x}) & @ti4 ~ (* x) & -5678 ~ $([i24] n) // -5678
@ti4 = (* {n x}) & @ti5 ~ (* x) & 2.8 ~ $([i24] n) // +2 (rounds to zero)
@ti5 = (* {n x}) & @ti6 ~ (* x) & -12.7 ~ $([i24] n) // -12 (rounds to zero)
@ti6 = (* {n x}) & @ti7 ~ (* x) & 8388610.0 ~ $([i24] n) // +8388607 (saturates)
@ti7 = (* {n x}) & @ti8 ~ (* x) & -8388610.0 ~ $([i24] n) // -8388608 (saturates)
@ti8 = (* {n x}) & @ti9 ~ (* x) & +inf ~ $([i24] n) // +8388607 (saturates)
@ti9 = (* {n x}) & @ti10 ~ (* x) & -inf ~ $([i24] n) // -8388608 (saturates)
@ti10 = (* {n x}) & @tf0 ~ (* x) & +NaN ~ $([i24] n) // +0

// casting to f24
@tf0 = (* {n x}) & @tf1 ~ (* x) & +NaN ~ $([f24] n) // +NaN
@tf1 = (* {n x}) & @tf2 ~ (* x) & +inf ~ $([f24] n) // +inf
@tf2 = (* {n x}) & @tf3 ~ (* x) & -inf ~ $([f24] n) // -inf
@tf3 = (* {n x}) & @tf4 ~ (* x) & 2.15 ~ $([f24] n) // 2.15
@tf4 = (* {n x}) & @tf5 ~ (* x) & -2.15 ~ $([f24] n) // -2.15
@tf5 = (* {n x}) & @tf6 ~ (* x) & 0.15 ~ $([f24] n) // 0.15
@tf6 = (* {n x}) & @tf7 ~ (* x) & -1234 ~ $([f24] n) // -1234.0
@tf7 = (* {n x}) & @tf8 ~ (* x) & +1234 ~ $([f24] n) // +1234.0
@tf8 = (* {n x}) & @tf9 ~ (* x) & 123456 ~ $([f24] n) // 123456.0
@tf9 = (* {n x}) & @tp0 ~ (* x) & 16775982 ~ $([f24] n) // 16775936.0

// printing
@tp0 = (* {n x}) & @tp1 ~ (* x) & n ~ [u24] // [u24]
@tp1 = (* {n x}) & @tp2 ~ (* x) & n ~ [i24] // [i24]
@tp2 = (* {n x}) & @t ~ (* x) & n ~ [f24] // [f24]

@t = *
Loading

0 comments on commit f158106

Please sign in to comment.