Skip to content

Commit

Permalink
Implement Base.cmp
Browse files Browse the repository at this point in the history
This operation is useful for implementing `:(==)`, `:(<)`, and `:(<=)`.
The more efficient `cmp` is, the more efficient these comparisons are.
  • Loading branch information
barucden committed Oct 25, 2024
2 parents 29a80b4 + 422b4ed commit 7890160
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 28 deletions.
75 changes: 47 additions & 28 deletions src/equals.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,64 @@
# Equality
_sign(x::BigInt) = Int(sign(x))
_sign(x::Decimal) = x.s ? -1 : 1

# equals() now depends on == instead
# of the other way round.
function Base.:(==)(x::Decimal, y::Decimal)
# return early on zero
x_is_zero = iszero(x)
y_is_zero = iszero(y)
if x_is_zero || y_is_zero
return x_is_zero === y_is_zero
function Base.cmp(x::Decimal, y::Decimal)
# We try to avoid computing x - y because it allocates a new BigInt

if iszero(x) && iszero(y)
return 0
elseif iszero(x) # && !iszero(y)
return -_sign(y)
elseif iszero(y) # && !iszero(x)
return _sign(x)
end

a = normalize(x)
b = normalize(y)
a.c == b.c && a.q == b.q && a.s == b.s
end
# Neither x nor y is zero here

function Base.:(<)(x::Decimal, y::Decimal)
# return early on zero
if iszero(x) && iszero(y)
return false
if x.s != y.s
# x and y have different signs, so
# if x < 0, then return -1 (because y is positive)
# if x > 0, then return +1 (because y is negative)
return _sign(x)
end

# avoid normalization if possible
if x.q == y.q
return isless(x.s == 0 ? x.c : -x.c, y.s == 0 ? y.c : -y.c)
cmp_c = cmp(x.c, y.c)
cmp_q = cmp(x.q, y.q)

# If both x.c and x.q is greater (or equal, or less) than y.c and y.q,
# then x is greater (or equal, or less) than y
if cmp_c == cmp_q
return cmp_c
end

diff = y - x
# Let x = a * 10^p, y = b * 10^q.
#
# If p ≥ q:
#
# sign(x - y)
# = sign(a * 10^p - b * 10^q)
# = sign((a * 10^(p - q) - b) * 10^q)
# = sign(a * 10^(p - q) - b)
#
# If p < q:
#
# sign(x - y)
# = a - b * 10^(q - p)

farther_from_0 = diff.c > 0 || (iszero(diff.c) && diff.q > 0)
xcoef = (-1)^x.s * x.c
ycoef = (-1)^y.s * y.c

if diff.s == 1
return !farther_from_0
if x.q y.q
q = x.q - y.q
return _sign(xcoef * big(10) ^ q - ycoef)
else
return farther_from_0
q = y.q - x.q
return _sign(xcoef - ycoef * big(10) ^ q)
end
end

function Base.:(<=)(x::Decimal, y::Decimal)
return x < y || x == y
end
Base.:(==)(x::Decimal, y::Decimal) = iszero(cmp(x, y))
Base.:(<)(x::Decimal, y::Decimal) = cmp(x, y) < 0
Base.:(<=)(x::Decimal, y::Decimal) = cmp(x, y) <= 0

# Special case equality with AbstractFloat to allow comparison against Inf/Nan
# which are not representable in Decimal
Expand Down
28 changes: 28 additions & 0 deletions test/test_equals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@ using Test

@testset "Equality" begin

@testset "cmp" begin
@test cmp(parse(Decimal, "-2" ), parse(Decimal, "-2")) == 0
@test cmp(parse(Decimal, "-2" ), parse(Decimal, "-1")) == -1
@test cmp(parse(Decimal, "-2" ), parse(Decimal, "0")) == -1
@test cmp(parse(Decimal, "-2" ), parse(Decimal, "1")) == -1
@test cmp(parse(Decimal, "-2" ), parse(Decimal, "2")) == -1
@test cmp(parse(Decimal, "-1" ), parse(Decimal, "-2")) == 1
@test cmp(parse(Decimal, "-1" ), parse(Decimal, "-1")) == 0
@test cmp(parse(Decimal, "-1" ), parse(Decimal, "0")) == -1
@test cmp(parse(Decimal, "-1" ), parse(Decimal, "1")) == -1
@test cmp(parse(Decimal, "-1" ), parse(Decimal, "2")) == -1
@test cmp(parse(Decimal, "0" ), parse(Decimal, "-2")) == 1
@test cmp(parse(Decimal, "0" ), parse(Decimal, "-1")) == 1
@test cmp(parse(Decimal, "0" ), parse(Decimal, "0")) == 0
@test cmp(parse(Decimal, "0" ), parse(Decimal, "1")) == -1
@test cmp(parse(Decimal, "0" ), parse(Decimal, "2")) == -1
@test cmp(parse(Decimal, "1" ), parse(Decimal, "-2")) == 1
@test cmp(parse(Decimal, "1" ), parse(Decimal, "-1")) == 1
@test cmp(parse(Decimal, "1" ), parse(Decimal, "0")) == 1
@test cmp(parse(Decimal, "1" ), parse(Decimal, "1")) == 0
@test cmp(parse(Decimal, "1" ), parse(Decimal, "2")) == -1
@test cmp(parse(Decimal, "2" ), parse(Decimal, "-2")) == 1
@test cmp(parse(Decimal, "2" ), parse(Decimal, "-1")) == 1
@test cmp(parse(Decimal, "2" ), parse(Decimal, "0")) == 1
@test cmp(parse(Decimal, "2" ), parse(Decimal, "1")) == 1
@test cmp(parse(Decimal, "2" ), parse(Decimal, "2")) == 0
end

@testset "isequal" begin
@test isequal(Decimal(false, 2, -3), Decimal(false, 2, -3))
@test !isequal(Decimal(false, 2, -3), Decimal(false, 2, 3))
Expand Down

0 comments on commit 7890160

Please sign in to comment.