Skip to content

Commit

Permalink
Add fp16 addition and sqrt()
Browse files Browse the repository at this point in the history
  • Loading branch information
romainguy committed Jun 17, 2022
1 parent 22bbf07 commit 9ae803a
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 33 deletions.
122 changes: 117 additions & 5 deletions src/commonMain/kotlin/dev/romainguy/kotlin/math/Half.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
* limitations under the License.
*/

// Operators +, *, / based on http://half.sourceforge.net/ by Christian Rau
// and licensed under MIT

@file:Suppress("NOTHING_TO_INLINE")

package dev.romainguy.kotlin.math
Expand Down Expand Up @@ -344,7 +347,7 @@ value class Half(private val v: UShort) : Comparable<Half> {
get() = when {
isNaN() -> NaN
isInfinite() -> POSITIVE_INFINITY
// 0x7bff == MAX_VALUE
// 0x7bff == MAX_VALUE, return 2^4
v.toInt() and FP16_ABS == 0x7bff -> Half(0x4c00.toUShort())
else -> {
val d = absoluteValue
Expand Down Expand Up @@ -465,7 +468,7 @@ value class Half(private val v: UShort) : Comparable<Half> {
fun nextUp(): Half = when {
isNaN() || v == POSITIVE_INFINITY.v -> this
isZero() -> MIN_VALUE
else -> fromBits(toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) 1 else -1)
else -> Half((toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) 1 else -1).toUShort())
}

/**
Expand All @@ -474,7 +477,7 @@ value class Half(private val v: UShort) : Comparable<Half> {
fun nextDown(): Half = when {
isNaN() || v == NEGATIVE_INFINITY.v -> this
isZero() -> -MIN_VALUE
else -> fromBits(toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) -1 else 1)
else -> Half((toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) -1 else 1).toUShort())
}

/**
Expand Down Expand Up @@ -519,7 +522,73 @@ value class Half(private val v: UShort) : Comparable<Half> {
operator fun unaryPlus() = Half(v)

operator fun plus(other: Half): Half {
TODO("Not yet implemented")
val xbits = toBits()
val ybits = other.toBits()

val sub = ((xbits xor ybits) and FP16_SIGN_MASK) != 0

var ax = xbits and FP16_ABS
var ay = ybits and FP16_ABS

// Handle NaNs and infinities
if (ax >= FP16_EXPONENT_MAX || ay >= FP16_EXPONENT_MAX) {
return Half((
if (ax > FP16_EXPONENT_MAX || ay > FP16_EXPONENT_MAX) quiet(ax, ay)
else if (ay != FP16_EXPONENT_MAX) xbits
else if (sub && ax == FP16_EXPONENT_MAX) FP16_QUIET_NAN
else ybits
).toUShort())
}

// Handle zero operands, including signs
if (ax == 0) return if (ay != 0) other else Half((xbits and ybits).toUShort())
if (ay == 0) return this

// Compute the sign of the result
val s = (if (sub && ay > ax) ybits else xbits) and FP16_SIGN_MASK

if (ay > ax) {
val t = ax
ax = ay
ay = t
}

var e = (ax shr 10) + if (ax <= FP16_SIGNIFICAND_MASK) 1 else 0
val d = e - (ay shr 10) - if (ay <= FP16_SIGNIFICAND_MASK) 1 else 0

var mx = ((ax and FP16_SIGNIFICAND_MASK) or
((if (ax > FP16_SIGNIFICAND_MASK) 1 else 0) shl 10)) shl 3
var my: Int

if (d < 13) {
my = ((ay and FP16_SIGNIFICAND_MASK) or
((if (ay > FP16_SIGNIFICAND_MASK) 1 else 0) shl 10)) shl 3
my = (my shr d) or (if ((my and ((1 shl d) - 1)) != 0) 1 else 0)
} else {
my = 1
}

if (sub) {
mx -= my
if (mx == 0) return POSITIVE_ZERO
while (mx < 0x2000 && e > 1) {
mx = mx shl 1
e--
}
} else {
mx += my
val i = mx shr 14
e += i
if (e > 30) return Half((s or FP16_EXPONENT_MAX).toUShort())
mx = (mx shr i) or (mx and i)
}

// Guard and sticky bits
val v = s +((e - 1) shl 10) + (mx shr 3)
val G = (mx shr 2) and 1
val S = if (mx and 0x3 != 0) 1 else 0

return Half((v + (G and (S or v))).toUShort())
}

operator fun minus(other: Half) = this + (-other)
Expand Down Expand Up @@ -704,7 +773,50 @@ value class Half(private val v: UShort) : Comparable<Half> {
}
}

fun sqrt(x: Half): Half = TODO("Not implemented yet")
fun sqrt(x: Half): Half {
val bits = x.toBits()
var a = bits and FP16_ABS
var e = 15

if (a == 0 || a >= FP16_EXPONENT_MAX) {
return Half((when {
a > FP16_EXPONENT_MAX -> quiet(bits)
bits > FP16_SIGN_MASK -> FP16_QUIET_NAN
else -> bits
}).toUShort())
}

while (a < 0x400) {
a = a shl 1
e--
}

// Bring back 1.
var r = ((a and FP16_SIGNIFICAND_MASK) or 0x400).toUInt() shl 10
e += a shr 10
val i = e and 1
r = r shl i
e = (e - i) / 2

var m = 0U
var b = 1U shl 20
while (b != 0U) {
if (r < m + b) {
m = m shr 1
} else {
r -= m + b
m = (m shr 1) + b
}
b = b shr 2
}

// Guard and sticky bits
val v = (e shl 10).toUInt() + (m and 0x3ffU)
val G = if (r > m) 1U else 0U
val S = if (r != 0U) 1U else 0U

return Half((v + (G and (S or v))).toUShort())
}

/**
* Returns the absolute value of the specified half-precision float.
Expand Down
4 changes: 2 additions & 2 deletions src/commonMain/kotlin/dev/romainguy/kotlin/math/Scalar.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ const val INV_PI = 1.0f / PI
const val INV_TWO_PI = INV_PI * 0.5f
const val INV_FOUR_PI = INV_PI * 0.25f

val HALF_ONE = Half(1.0f)
val HALF_TWO = Half(2.0f)
val HALF_ONE = Half(0x3c00.toUShort())
val HALF_TWO = Half(0x4000.toUShort())

inline fun clamp(x: Float, min: Float, max: Float) = if (x < min) min else (if (x > max) max else x)

Expand Down
139 changes: 113 additions & 26 deletions src/commonTest/kotlin/dev/romainguy/kotlin/math/HalfTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,17 @@ class HalfTest {

assertEquals(Half.MIN_VALUE, Half.NEGATIVE_ZERO.nextUp())
assertEquals(-Half.MIN_VALUE, Half.NEGATIVE_ZERO.nextDown())

assertTrue(Half.NaN.nextTowards(HALF_TWO).isNaN())
assertTrue(HALF_TWO.nextTowards(Half.NaN).isNaN())
assertEquals(HALF_ONE, HALF_ONE.nextTowards(HALF_ONE))
assertEquals(-HALF_ONE, (-HALF_ONE).nextTowards(-HALF_ONE))

assertEquals(Half(1025.0f), Half(1024.0f).nextTowards(Half(32768.0f)))
assertEquals(Half(1023.5f), Half(1024.0f).nextTowards(Half(-32768.0f)))

assertEquals(Half(0.50048830f), Half(0.5f).nextTowards(Half(32768.0f)))
assertEquals(Half(0.49975586f), Half(0.5f).nextTowards(Half(-32768.0f)))
}

@Test
Expand All @@ -549,8 +560,8 @@ class HalfTest {

@Test
fun multiplication() {
assertTrue((Half(2.0f) * Half.NaN).isNaN())
assertTrue((Half.NaN * Half(2.0f)).isNaN())
assertTrue((HALF_TWO * Half.NaN).isNaN())
assertTrue((Half.NaN * HALF_TWO).isNaN())
assertTrue((Half.POSITIVE_INFINITY * Half.NaN).isNaN())
assertTrue((Half.NaN * Half.POSITIVE_INFINITY).isNaN())
assertTrue((Half.NEGATIVE_INFINITY * Half.NaN).isNaN())
Expand All @@ -560,29 +571,29 @@ class HalfTest {
assertTrue((Half.NEGATIVE_ZERO * Half.NaN).isNaN())
assertTrue((Half.NaN * Half.NEGATIVE_ZERO).isNaN())

assertTrue((Half(2.0f) * Half.POSITIVE_INFINITY).isInfinite())
assertTrue((Half.POSITIVE_INFINITY * Half(2.0f)).isInfinite())
assertTrue((HALF_TWO * Half.POSITIVE_INFINITY).isInfinite())
assertTrue((Half.POSITIVE_INFINITY * HALF_TWO).isInfinite())

assertTrue((Half(2.0f) * Half.NEGATIVE_INFINITY).isInfinite())
assertTrue((Half.NEGATIVE_INFINITY * Half(2.0f)).isInfinite())
assertTrue((HALF_TWO * Half.NEGATIVE_INFINITY).isInfinite())
assertTrue((Half.NEGATIVE_INFINITY * HALF_TWO).isInfinite())

assertTrue((Half(2.0f) * Half.POSITIVE_ZERO).isZero())
assertTrue((Half.POSITIVE_ZERO * Half(2.0f)).isZero())
assertTrue((HALF_TWO * Half.POSITIVE_ZERO).isZero())
assertTrue((Half.POSITIVE_ZERO * HALF_TWO).isZero())

assertTrue((Half(2.0f) * Half.NEGATIVE_ZERO).isZero())
assertTrue((Half.NEGATIVE_ZERO * Half(2.0f)).isZero())
assertTrue((HALF_TWO * Half.NEGATIVE_ZERO).isZero())
assertTrue((Half.NEGATIVE_ZERO * HALF_TWO).isZero())

// Overflow
assertEquals(Half.POSITIVE_INFINITY, Half(2.0f) * Half.MAX_VALUE)
assertEquals(Half.POSITIVE_INFINITY, Half.MAX_VALUE * Half(2.0f))
assertEquals(Half.POSITIVE_INFINITY, HALF_TWO * Half.MAX_VALUE)
assertEquals(Half.POSITIVE_INFINITY, Half.MAX_VALUE * HALF_TWO)
assertEquals(Half.NEGATIVE_INFINITY, Half(-2.0f) * Half.MAX_VALUE)
assertEquals(Half.NEGATIVE_INFINITY, Half.MAX_VALUE * Half(-2.0f))

// Underflow
assertEquals(Half.POSITIVE_ZERO, Half.MIN_VALUE * Half.MIN_NORMAL)
assertEquals(Half.NEGATIVE_ZERO, Half.MIN_VALUE * -Half.MIN_NORMAL)

assertEquals(Half(8.0f), Half(2.0f) * Half(4.0f))
assertEquals(Half(8.0f), HALF_TWO * Half(4.0f))
assertEquals(Half(2.88f), Half(1.2f) * Half(2.4f))
assertEquals(Half(-2.88f), Half(1.2f) * Half(-2.4f))
assertEquals(Half(-2.88f), Half(-1.2f) * Half(2.4f))
Expand All @@ -597,8 +608,8 @@ class HalfTest {

@Test
fun division() {
assertTrue((Half(2.0f) / Half.NaN).isNaN())
assertTrue((Half.NaN / Half(2.0f)).isNaN())
assertTrue((HALF_TWO / Half.NaN).isNaN())
assertTrue((Half.NaN / HALF_TWO).isNaN())
assertTrue((Half.POSITIVE_INFINITY / Half.NaN).isNaN())
assertTrue((Half.NaN / Half.POSITIVE_INFINITY).isNaN())
assertTrue((Half.NEGATIVE_INFINITY / Half.NaN).isNaN())
Expand All @@ -608,17 +619,17 @@ class HalfTest {
assertTrue((Half.NEGATIVE_ZERO / Half.NaN).isNaN())
assertTrue((Half.NaN / Half.NEGATIVE_ZERO).isNaN())

assertTrue((Half(2.0f) / Half.POSITIVE_INFINITY).isZero())
assertTrue((Half.POSITIVE_INFINITY / Half(2.0f)).isInfinite())
assertTrue((HALF_TWO / Half.POSITIVE_INFINITY).isZero())
assertTrue((Half.POSITIVE_INFINITY / HALF_TWO).isInfinite())

assertTrue((Half(2.0f) / Half.NEGATIVE_INFINITY).isZero())
assertTrue((Half.NEGATIVE_INFINITY / Half(2.0f)).isInfinite())
assertTrue((HALF_TWO / Half.NEGATIVE_INFINITY).isZero())
assertTrue((Half.NEGATIVE_INFINITY / HALF_TWO).isInfinite())

assertTrue((Half(2.0f) / Half.POSITIVE_ZERO).isInfinite())
assertTrue((Half.POSITIVE_ZERO / Half(2.0f)).isZero())
assertTrue((HALF_TWO / Half.POSITIVE_ZERO).isInfinite())
assertTrue((Half.POSITIVE_ZERO / HALF_TWO).isZero())

assertTrue((Half(2.0f) / Half.NEGATIVE_ZERO).isInfinite())
assertTrue((Half.NEGATIVE_ZERO / Half(2.0f)).isZero())
assertTrue((HALF_TWO / Half.NEGATIVE_ZERO).isInfinite())
assertTrue((Half.NEGATIVE_ZERO / HALF_TWO).isZero())

// Underflow
assertEquals(Half.POSITIVE_ZERO, Half.MIN_VALUE / Half.MAX_VALUE)
Expand All @@ -628,7 +639,7 @@ class HalfTest {
assertEquals(Half.POSITIVE_INFINITY, Half.MAX_VALUE / Half.MIN_VALUE)
assertEquals(Half.NEGATIVE_INFINITY, (-Half.MAX_VALUE) / Half.MIN_VALUE)

assertEquals(Half(0.5f), Half(2.0f) / Half(4.0f))
assertEquals(Half(0.5f), HALF_TWO / Half(4.0f))
assertEquals(Half(0.5f), Half(1.2f) / Half(2.4f))
assertEquals(Half(-0.5f), Half(1.2f) / Half(-2.4f))
assertEquals(Half(-0.5f), Half(-1.2f) / Half(2.4f))
Expand All @@ -637,10 +648,86 @@ class HalfTest {
assertEquals(Half(16_000.0f), Half(48_000.0f) / Half(3.0f))
assertEquals(Half(-16_000.0f), Half(48_000.0f) / Half(-3.0f))

assertEquals(Half(2.0861626e-5f), Half(1.0f) / Half(48_000.0f))
assertEquals(Half(-2.0861626e-5), Half(1.0f) / Half(-48_000.0f))
assertEquals(Half(2.0861626e-5f), HALF_ONE / Half(48_000.0f))
assertEquals(Half(-2.0861626e-5), HALF_ONE / Half(-48_000.0f))

assertEquals(Half(75.0f), Half(0.03f) / Half(0.0004f))
assertEquals(Half(-75.0f), Half(0.03f) / Half(-0.0004f))
}

@Test
fun addition() {
assertTrue((Half.NaN + Half.NaN).isNaN())

assertTrue((Half.NaN + HALF_ONE).isNaN())
assertTrue((Half.NaN - HALF_ONE).isNaN())
assertTrue((HALF_ONE + Half.NaN).isNaN())
assertTrue((Half(-1.0f) + Half.NaN).isNaN())

assertTrue((Half.NaN + Half.POSITIVE_INFINITY).isNaN())
assertTrue((Half.POSITIVE_INFINITY + Half.NaN).isNaN())
assertTrue((Half.NaN + Half.NEGATIVE_INFINITY).isNaN())
assertTrue((Half.NEGATIVE_INFINITY + Half.NaN).isNaN())

assertTrue((Half.NaN + Half.POSITIVE_ZERO).isNaN())
assertTrue((Half.POSITIVE_ZERO + Half.NaN).isNaN())
assertTrue((Half.NaN + Half.NEGATIVE_ZERO).isNaN())
assertTrue((Half.NEGATIVE_ZERO + Half.NaN).isNaN())

assertTrue((Half.POSITIVE_INFINITY + HALF_ONE).isInfinite())
assertTrue((Half.POSITIVE_INFINITY - HALF_ONE).isInfinite())
assertTrue((HALF_ONE + Half.POSITIVE_INFINITY).isInfinite())
assertTrue((HALF_ONE - Half.POSITIVE_INFINITY).isInfinite())
assertTrue((Half.POSITIVE_INFINITY + Half.POSITIVE_INFINITY).isInfinite())
assertTrue((Half.POSITIVE_INFINITY - Half.POSITIVE_INFINITY).isNaN())

assertTrue((Half.NEGATIVE_INFINITY - HALF_ONE).isInfinite())
assertTrue((Half.NEGATIVE_INFINITY + HALF_ONE).isInfinite())
assertTrue((HALF_ONE + Half.NEGATIVE_INFINITY).isInfinite())
assertTrue((HALF_ONE - Half.NEGATIVE_INFINITY).isInfinite())
assertTrue((Half.NEGATIVE_INFINITY + Half.NEGATIVE_INFINITY).isInfinite())
assertTrue((Half.NEGATIVE_INFINITY - Half.NEGATIVE_INFINITY).isNaN())

assertEquals(Half(3.0f), HALF_ONE + HALF_TWO)

// Overflow
assertEquals(Half.POSITIVE_INFINITY, Half(32768.0f) + Half(32768.0f))
// Underflow
assertEquals(Half.NEGATIVE_INFINITY, Half(-32768.0f) - Half(32768.0f))

for (i in 0x0..0xffff) {
val v1 = Half(i.toUShort())
if (v1.isFinite()) {
assertTrue((v1 - v1).isZero())
assertEquals(v1 * HALF_TWO, v1 + v1)
}
}
}

@Test
fun ulp() {
assertTrue(Half.NaN.ulp.isNaN())

assertTrue(Half.POSITIVE_INFINITY.ulp.isInfinite())
assertTrue(Half.NEGATIVE_INFINITY.ulp.isInfinite())

assertTrue((Half.MAX_VALUE + Half.MAX_VALUE.ulp).isInfinite())

assertEquals(Half.MIN_VALUE, Half.POSITIVE_ZERO.ulp)
assertEquals(Half.MIN_VALUE, Half.NEGATIVE_ZERO.ulp)

assertEquals(HALF_ONE, Half(1024.0f).ulp)
assertEquals(HALF_ONE, Half(-1024.0f).ulp)
}

@Test
fun sqrt() {
for (i in 0x0..0xffff) {
val v1 = Half(i.toUShort())
if (v1.isFinite()) {
val v2 = sqrt(v1)
assertTrue(v1 - (v2 * v2) <= HALF_TWO * v1.ulp)
}
}
}
}

0 comments on commit 9ae803a

Please sign in to comment.