From 422b4ede900ce49f8d19e80440dc55beab6ab038 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Denis=20Baru=C4=8Di=C4=87?= Date: Wed, 16 Oct 2024 08:57:50 +0200 Subject: [PATCH] Implement `Base.cmp` This operation is useful for implementing `:(==)`, `:(<)`, and `:(<=)`. The more efficient `cmp` is, the more efficient these comparisons are. --- src/equals.jl | 77 +++++++++++++++++++++++++++------------------ test/test_equals.jl | 28 +++++++++++++++++ 2 files changed, 75 insertions(+), 30 deletions(-) diff --git a/src/equals.jl b/src/equals.jl index e1b5367..779a0be 100644 --- a/src/equals.jl +++ b/src/equals.jl @@ -1,47 +1,64 @@ -# Equality +_sign(x::BigInt) = Int(sign(x)) +_sign(x::Decimal) = x.s ? -1 : 1 -# equals() now depends on == instead -# of the other way round. -function Base.:(==)(x::Decimal, y::Decimal) - # return early on zero - x_is_zero = iszero(x) - y_is_zero = iszero(y) - if x_is_zero || y_is_zero - return x_is_zero === y_is_zero - end +function Base.cmp(x::Decimal, y::Decimal) + # We try to avoid computing x - y because it allocates a new BigInt - a = normalize(x) - b = normalize(y) - a.c == b.c && a.q == b.q && a.s == b.s -end + if iszero(x) && iszero(y) + return 0 + elseif iszero(x) # && !iszero(y) + return -_sign(y) + elseif iszero(y) # && !iszero(x) + return _sign(x) + end -Base.iszero(x::Decimal) = iszero(x.c) + # Neither x nor y is zero here -function Base.:(<)(x::Decimal, y::Decimal) - # return early on zero - if iszero(x) && iszero(y) - return false + if x.s != y.s + # x and y have different signs, so + # if x < 0, then return -1 (because y is positive) + # if x > 0, then return +1 (because y is negative) + return _sign(x) end - # avoid normalization if possible - if x.q == y.q - return isless(x.s == 0 ? x.c : -x.c, y.s == 0 ? y.c : -y.c) + cmp_c = cmp(x.c, y.c) + cmp_q = cmp(x.q, y.q) + + # If both x.c and x.q is greater (or equal, or less) than y.c and y.q, + # then x is greater (or equal, or less) than y + if cmp_c == cmp_q + return cmp_c end - diff = y - x + # Let x = a * 10^p, y = b * 10^q. + # + # If p ≥ q: + # + # sign(x - y) + # = sign(a * 10^p - b * 10^q) + # = sign((a * 10^(p - q) - b) * 10^q) + # = sign(a * 10^(p - q) - b) + # + # If p < q: + # + # sign(x - y) + # = a - b * 10^(q - p) - farther_from_0 = diff.c > 0 || (iszero(diff.c) && diff.q > 0) + xcoef = (-1)^x.s * x.c + ycoef = (-1)^y.s * y.c - if diff.s == 1 - return !farther_from_0 + if x.q ≥ y.q + q = x.q - y.q + return _sign(xcoef * big(10) ^ q - ycoef) else - return farther_from_0 + q = y.q - x.q + return _sign(xcoef - ycoef * big(10) ^ q) end end -function Base.:(<=)(x::Decimal, y::Decimal) - return x < y || x == y -end +Base.:(==)(x::Decimal, y::Decimal) = iszero(cmp(x, y)) +Base.:(<)(x::Decimal, y::Decimal) = cmp(x, y) < 0 +Base.:(<=)(x::Decimal, y::Decimal) = cmp(x, y) <= 0 # Special case equality with AbstractFloat to allow comparison against Inf/Nan # which are not representable in Decimal diff --git a/test/test_equals.jl b/test/test_equals.jl index 07c6529..cef2601 100644 --- a/test/test_equals.jl +++ b/test/test_equals.jl @@ -3,6 +3,34 @@ using Test @testset "Equality" begin +@testset "cmp" begin + @test cmp(parse(Decimal, "-2" ), parse(Decimal, "-2")) == 0 + @test cmp(parse(Decimal, "-2" ), parse(Decimal, "-1")) == -1 + @test cmp(parse(Decimal, "-2" ), parse(Decimal, "0")) == -1 + @test cmp(parse(Decimal, "-2" ), parse(Decimal, "1")) == -1 + @test cmp(parse(Decimal, "-2" ), parse(Decimal, "2")) == -1 + @test cmp(parse(Decimal, "-1" ), parse(Decimal, "-2")) == 1 + @test cmp(parse(Decimal, "-1" ), parse(Decimal, "-1")) == 0 + @test cmp(parse(Decimal, "-1" ), parse(Decimal, "0")) == -1 + @test cmp(parse(Decimal, "-1" ), parse(Decimal, "1")) == -1 + @test cmp(parse(Decimal, "-1" ), parse(Decimal, "2")) == -1 + @test cmp(parse(Decimal, "0" ), parse(Decimal, "-2")) == 1 + @test cmp(parse(Decimal, "0" ), parse(Decimal, "-1")) == 1 + @test cmp(parse(Decimal, "0" ), parse(Decimal, "0")) == 0 + @test cmp(parse(Decimal, "0" ), parse(Decimal, "1")) == -1 + @test cmp(parse(Decimal, "0" ), parse(Decimal, "2")) == -1 + @test cmp(parse(Decimal, "1" ), parse(Decimal, "-2")) == 1 + @test cmp(parse(Decimal, "1" ), parse(Decimal, "-1")) == 1 + @test cmp(parse(Decimal, "1" ), parse(Decimal, "0")) == 1 + @test cmp(parse(Decimal, "1" ), parse(Decimal, "1")) == 0 + @test cmp(parse(Decimal, "1" ), parse(Decimal, "2")) == -1 + @test cmp(parse(Decimal, "2" ), parse(Decimal, "-2")) == 1 + @test cmp(parse(Decimal, "2" ), parse(Decimal, "-1")) == 1 + @test cmp(parse(Decimal, "2" ), parse(Decimal, "0")) == 1 + @test cmp(parse(Decimal, "2" ), parse(Decimal, "1")) == 1 + @test cmp(parse(Decimal, "2" ), parse(Decimal, "2")) == 0 +end + @testset "isequal" begin @test isequal(Decimal(false, 2, -3), Decimal(false, 2, -3)) @test !isequal(Decimal(false, 2, -3), Decimal(false, 2, 3))