Skip to content

Commit

Permalink
Implement Base.cmp
Browse files Browse the repository at this point in the history
This operation is useful for implementing `:(==)`, `:(<)`, and `:(<=)`.
The more efficient `cmp` is, the more efficient these comparisons are.
  • Loading branch information
barucden committed Oct 25, 2024
1 parent 29a80b4 commit 6eee6cd
Show file tree
Hide file tree
Showing 3 changed files with 897 additions and 28 deletions.
96 changes: 68 additions & 28 deletions src/equals.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,85 @@
# Equality
_sign(x::Decimal) = x.s ? -1 : 1

# equals() now depends on == instead
# of the other way round.
function Base.:(==)(x::Decimal, y::Decimal)
# return early on zero
x_is_zero = iszero(x)
y_is_zero = iszero(y)
if x_is_zero || y_is_zero
return x_is_zero === y_is_zero
function Base.cmp(x::Decimal, y::Decimal)
if iszero(x) && iszero(y)
return 0
elseif iszero(x) # && !iszero(y)
return -_sign(y)
elseif iszero(y) # && !iszero(x)
return _sign(x)
end

a = normalize(x)
b = normalize(y)
a.c == b.c && a.q == b.q && a.s == b.s
end
# Neither x nor y is zero here

function Base.:(<)(x::Decimal, y::Decimal)
# return early on zero
if iszero(x) && iszero(y)
return false
if x.s != y.s
# x and y have different signs, so
# if x < 0, then return -1 (because y is positive)
# if x > 0, then return +1 (because y is negative)
return _sign(x)
end

# avoid normalization if possible
if x.q == y.q
return isless(x.s == 0 ? x.c : -x.c, y.s == 0 ? y.c : -y.c)
cmp_c = cmp(x.c, y.c)
cmp_q = cmp(x.q, y.q)

# If both x.c and x.q is greater (or equal, or less) than y.c and y.q,
# then x is greater (or equal, or less) than y
if cmp_c == cmp_q
return cmp_c
end

diff = y - x
# Adjusted exponent of x and y
# It is the position of the most significant digit with respect to
# the decimal point
expx = ndigits(x.c) + x.q - 1
expy = ndigits(y.c) + y.q - 1

farther_from_0 = diff.c > 0 || (iszero(diff.c) && diff.q > 0)
# If expx > expy, then abs(x) > abs(y)
# If expx < expy, then abs(x) < abs(y)
#
# Then we need to consider the sign, which is the same for x and y here
#
# Overall:
# -1 if expx > expy and they are negative
# +1 if expx > expy and they are positive
# -1 if expx < expy and they are positive
# +1 if expx < expy and they are negative
if expx != expy
s = _sign(x) # same as _sign(y)
return ifelse(expx > expy, s, -s)
end

if diff.s == 1
return !farther_from_0
# cmp(x, y) = sign(x - y)
# = sign(sign(x) * abs(x) - sign(y) * abs(y))
#
# We know that x and y have the same sign here:
#
# cmp(x, y) = sign(sign(x) * (abs(x) - abs(y)))
# = sign(x) * sign(abs(x) - abs(y))
# = sign(x) * sign(x.c * 10^x.q - y.c * 10^y.q)
#
# Now, for the latter sign:
#
# sign(x.c * 10^x.q - y.c * 10^y.q)
# = sign(x.c * 10^(x.q - y.q) - y.c) * 10^y.q
# = sign(x.c - y.c * 10^(y.q - x.q)) * 10^x.q
# ^^^^^^ positive
#
# So, we just need to return
#
# sign(x) * sign(x.c * 10^(x.q - y.q) - y.c) if x.q ≥ y.q,
# sign(x) * sign(x.c - y.c * 10^(y.q - x.q)) if x.q < y.q
if x.q y.q
q = x.q - y.q
return _sign(x) * cmp(x.c * big(10) ^ q, y.c)
else
return farther_from_0
q = y.q - x.q
return _sign(x) * cmp(x.c, y.c * big(10) ^ q)
end
end

function Base.:(<=)(x::Decimal, y::Decimal)
return x < y || x == y
end
Base.:(==)(x::Decimal, y::Decimal) = iszero(cmp(x, y))
Base.:(<)(x::Decimal, y::Decimal) = cmp(x, y) < 0
Base.:(<=)(x::Decimal, y::Decimal) = cmp(x, y) <= 0

# Special case equality with AbstractFloat to allow comparison against Inf/Nan
# which are not representable in Decimal
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ include("test_hash.jl")
include("test_norm.jl")
include("test_round.jl")

include("test_compare.jl")

end
Loading

0 comments on commit 6eee6cd

Please sign in to comment.