Skip to content

Commit

Permalink
Context
Browse files Browse the repository at this point in the history
This commit introduces `Context`, a structure that holds configuration
of the decimal arithmetics.

Eventually, the global variable `DIGITS` should be completely removed in
favor of this newly-added structure.
  • Loading branch information
barucden committed Nov 1, 2024
1 parent a789905 commit 465c69b
Show file tree
Hide file tree
Showing 13 changed files with 1,821 additions and 893 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.8'
- '1'
os:
- ubuntu-latest
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ name = "Decimals"
uuid = "abce61dc-4473-55a0-ba07-351d65e31d42"
version = "0.4.1"

[deps]
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"

[compat]
julia = "1"
ScopedValues = "1"
julia = "1.8"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
148 changes: 91 additions & 57 deletions scripts/dectest.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
function _precision(line)
m = match(r"^precision:\s*(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _rounding(line)
m = match(r"^rounding:\s*(\w+)$", line)
return Symbol(m[1])
isnothing(m) && throw(ArgumentError(line))
r = m[1]
if r == "ceiling"
return "RoundUp"
elseif r == "down"
return "RoundToZero"
elseif r == "floor"
return "RoundDown"
elseif r == "half_even"
return "RoundNearest"
elseif r == "half_up"
return "RoundNearestTiesAway"
elseif r == "up"
return "RoundFromZero"
elseif r == "half_down"
return "RoundHalfDownUnsupported"
elseif r == "05up"
return "Round05UpUnsupported"
else
throw(ArgumentError(r))
end
end

function _maxexponent(line)
m = match(r"^maxexponent:\s*(\d+)$", line)
m = match(r"^maxexponent:\s*\+?(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _minexponent(line)
m = match(r"^minexponent:\s*(-\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _test(line)
occursin("->", line) || throw(ArgumentError(line))
lhs, rhs = split(line, "->")
id, operation, operands... = split(lhs)
result, conditions... = split(rhs)
Expand All @@ -31,47 +55,55 @@ function decimal(x)
return "dec\"$x\""
end

print_precision(io, p::Int) = println(io, " setprecision(Decimal, $p)")
print_maxexponent(io, e::Int) = println(io, " Decimals.CONTEXT.Emax = $e")
print_minexponent(io, e::Int) = println(io, " Decimals.CONTEXT.Emin = $e")
function print_rounding(io, r::Symbol)
modes = Dict(:ceiling => "RoundUp",
:down => "RoundToZero",
:floor => "RoundDown",
:half_even => "RoundNearest",
:half_up => "RoundNearestTiesAway",
:up => "RoundFromZero",
:half_down => "RoundHalfDownUnsupported",
Symbol("05up") => "Round05UpUnsupported")
haskey(modes, r) || throw(ArgumentError(r))
rmod = modes[r]
println(io, " setrounding(Decimal, $rmod)")
end

function print_operation(io, operation, operands)
if operation == "plus"
print_plus(io, operands...)
elseif operation == "minus"
print_minus(io, operands...)
if operation == "abs"
print_abs(io, operands...)
elseif operation == "add"
print_add(io, operands...)
elseif operation == "apply"
print_apply(io, operands...)
elseif operation == "compare"
print_compare(io, operands...)
elseif operation == "divide"
print_divide(io, operands...)
elseif operation == "minus"
print_minus(io, operands...)
elseif operation == "multiply"
print_multiply(io, operands...)
elseif operation == "plus"
print_plus(io, operands...)
elseif operation == "reduce"
print_reduce(io, operands...)
elseif operation == "subtract"
print_subtract(io, operands...)
else
throw(ArgumentError(operation))
end
end
print_abs(io, x) = print(io, "abs(", decimal(x), ")")
print_add(io, x, y) = print(io, decimal(x), " + ", decimal(y))
print_apply(io, x) = print(io, decimal(x))
print_compare(io, x, y) = print(io, "cmp(", decimal(x), ", ", decimal(y), ")")
print_divide(io, x, y) = print(io, decimal(x), " / ", decimal(y))
print_minus(io, x) = print(io, "-(", decimal(x), ")")
print_multiply(io, x, y) = print(io, decimal(x), " * ", decimal(y))
print_plus(io, x) = print(io, "+(", decimal(x), ")")
print_reduce(io, x) = print(io, "reduce(", decimal(x), ")")
print_subtract(io, x, y) = print(io, decimal(x), " - ", decimal(y))

function print_test(io, test)
function print_test(io, test, directives)
println(io, " # $(test.id)")

names = sort!(collect(keys(directives)))
params = join(("$k=$(directives[k])" for k in names), ", ")
print(io, " @with_context ($params) ")

if :overflow test.conditions
print(io, " @test_throws OverflowError ")
print(io, "@test_throws OverflowError ")
print_operation(io, test.operation, test.operands)
println(io)
else
print(io, " @test ")
print(io, "@test ")
print_operation(io, test.operation, test.operands)
print(io, " == ")
println(io, decimal(test.result))
Expand All @@ -83,34 +115,36 @@ function isspecial(value)
return occursin(r"(inf|nan|#)", value)
end

function translate(io, line)
isempty(line) && return
startswith(line, "--") && return

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
precision = _precision(line)
print_precision(io, precision)
elseif startswith(line, "rounding:")
rounding = _rounding(line)
print_rounding(io, rounding)
elseif startswith(line, "maxexponent:")
maxexponent = _maxexponent(line)
print_maxexponent(io, maxexponent)
elseif startswith(line, "minexponent:")
minexponent = _minexponent(line)
print_minexponent(io, minexponent)
else
test = _test(line)
any(isspecial, test.operands) && return
print_test(io, test)
function translate(io, dectest_path)
directives = Dict{String, Any}()

for line in eachline(dectest_path)
line = strip(line)

isempty(line) && continue
startswith(line, "--") && continue

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
directives["precision"] = _precision(line)
elseif startswith(line, "rounding:")
directives["rounding"] = _rounding(line)
elseif startswith(line, "maxexponent:")
directives["Emax"] = _maxexponent(line)
elseif startswith(line, "minexponent:")
directives["Emin"] = _minexponent(line)
else
test = _test(line)
any(isspecial, test.operands) && continue
print_test(io, test, directives)
end
end
end

Expand All @@ -120,13 +154,13 @@ function (@main)(args=ARGS)
open(output_path, "w") do io
println(io, """
using Decimals
using ScopedValues
using Test
using Decimals: @with_context
@testset \"$name\" begin""")

for line in eachline(dectest_path)
translate(io, line)
end
translate(io, dectest_path)

println(io, "end")
end
Expand Down
1 change: 1 addition & 0 deletions src/Decimals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct Decimal <: AbstractFloat
end

include("bigint.jl")
include("context.jl")

# Convert between Decimal objects, numbers, and strings
include("decimal.jl")
Expand Down
6 changes: 2 additions & 4 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ Base.promote_rule(::Type{Decimal}, ::Type{<:Real}) = Decimal
Base.promote_rule(::Type{BigFloat}, ::Type{Decimal}) = Decimal
Base.promote_rule(::Type{BigInt}, ::Type{Decimal}) = Decimal

const BigTen = BigInt(10)
Base.:(+)(x::Decimal) = fix(x)
Base.:(-)(x::Decimal) = fix(Decimal(!x.s, x.c, x.q))

# Addition
# To add, convert both decimals to the same exponent.
Expand All @@ -24,9 +25,6 @@ function Base.:(+)(x::Decimal, y::Decimal)
return normalize(Decimal(s, abs(c), y.q))
end

# Negation
Base.:(-)(x::Decimal) = Decimal(!x.s, x.c, x.q)

# Subtraction
Base.:(-)(x::Decimal, y::Decimal) = +(x, -y)

Expand Down
2 changes: 2 additions & 0 deletions src/bigint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ else
const libgmp = Base.GMP.libgmp
end

const BigTen = BigInt(10)

function isdivisible(x::BigInt, n::Int)
r = ccall((:__gmpz_divisible_ui_p, libgmp), Cint,
(Base.GMP.MPZ.mpz_t, Culong), x, n)
Expand Down
Loading

0 comments on commit 465c69b

Please sign in to comment.