From 99ff0eaa8cb00d216367512d8276c98b60a826dd Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Sun, 21 Jan 2024 16:42:57 -0500 Subject: [PATCH 1/6] throw errors when only finite weights are supported --- src/sampling.jl | 6 ++++++ src/scalarstats.jl | 2 ++ src/weights.jl | 1 + 3 files changed, 9 insertions(+) diff --git a/src/sampling.jl b/src/sampling.jl index 610dde99b..e1020f365 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -590,6 +590,7 @@ Optionally specify a random number generator `rng` as the first argument function sample(rng::AbstractRNG, wv::AbstractWeights) 1 == firstindex(wv) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) t = rand(rng) * sum(wv) n = length(wv) i = 1 @@ -714,6 +715,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) @@ -752,6 +754,7 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) k = length(x) @@ -798,6 +801,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) @@ -839,6 +843,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) @@ -912,6 +917,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) diff --git a/src/scalarstats.jl b/src/scalarstats.jl index 83b664d5f..16f02ae0f 100644 --- a/src/scalarstats.jl +++ b/src/scalarstats.jl @@ -163,6 +163,7 @@ end # Weighted mode of arbitrary vectors of values function mode(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) length(a) == length(wv) || throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))")) @@ -184,6 +185,7 @@ end function modes(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) length(a) == length(wv) || throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))")) diff --git a/src/weights.jl b/src/weights.jl index f5f515104..ec415d3d9 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -716,6 +716,7 @@ function quantile(v::AbstractVector{<:Real}{V}, w::AbstractWeights{W}, p::Abstra # checks isempty(v) && throw(ArgumentError("quantile of an empty array is undefined")) isempty(p) && throw(ArgumentError("empty quantile array")) + isfinite(sum(w)) || throw(ArgumentError("only finite weights are supported")) all(x -> 0 <= x <= 1, p) || throw(ArgumentError("input probability out of [0,1] range")) w.sum == 0 && throw(ArgumentError("weight vector cannot sum to zero")) From c1355a83374b60f9a91ec81973d9c0bbf76f6eaa Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 15 Feb 2024 17:20:56 -0500 Subject: [PATCH 2/6] remove extra calls to sum(wv) --- src/sampling.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index e1020f365..63c80bb42 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -590,8 +590,9 @@ Optionally specify a random number generator `rng` as the first argument function sample(rng::AbstractRNG, wv::AbstractWeights) 1 == firstindex(wv) || throw(ArgumentError("non 1-based arrays are not supported")) - isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) - t = rand(rng) * sum(wv) + wsum = sum(wv) + isfinite(wsum) || throw(ArgumentError("only finite weights are supported")) + t = rand(rng) * wsum n = length(wv) i = 1 cw = wv[1] @@ -715,19 +716,20 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) - isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) + wsum = sum(wv) + isfinite(wsum) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) # create alias table ap = Vector{Float64}(undef, n) alias = Vector{Int}(undef, n) - make_alias_table!(wv, sum(wv), ap, alias) + make_alias_table!(wv, s, ap, alias) # sampling s = Sampler(rng, 1:n) for i = 1:length(x) - j = rand(rng, s) + j = rand(rng, wsum) x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]] end return x @@ -754,14 +756,14 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) - isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) + wsum = sum(wv) + isfinite(s) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) k = length(x) w = Vector{Float64}(undef, n) copyto!(w, wv) - wsum = sum(wv) for i = 1:k u = rand(rng) * wsum From 8105d72f063e72441b18f070f2acaa772c6014d2 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 15 Feb 2024 17:23:01 -0500 Subject: [PATCH 3/6] typo --- src/sampling.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 63c80bb42..941c1882d 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -724,7 +724,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, # create alias table ap = Vector{Float64}(undef, n) alias = Vector{Int}(undef, n) - make_alias_table!(wv, s, ap, alias) + make_alias_table!(wv, wsum, ap, alias) # sampling s = Sampler(rng, 1:n) @@ -757,7 +757,7 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) wsum = sum(wv) - isfinite(s) || throw(ArgumentError("only finite weights are supported")) + isfinite(wsum) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) k = length(x) From dafd32b9d14b342f4344174530c12ff679e41619 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 15 Feb 2024 17:23:50 -0500 Subject: [PATCH 4/6] typo --- src/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index 941c1882d..64e7843c2 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -729,7 +729,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, # sampling s = Sampler(rng, 1:n) for i = 1:length(x) - j = rand(rng, wsum) + j = rand(rng, s) x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]] end return x From 588ac75ee4c363c05ad57fc736e81074ea418dc0 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 16 May 2024 08:22:35 -0400 Subject: [PATCH 5/6] add a minimal test for custom weights implementations --- test/weights.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/weights.jl b/test/weights.jl index 2180c88a4..133fe8a88 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -610,4 +610,17 @@ end end end +@testset "custom weight types" begin + struct MyWeights <: AbstractWeights{Float64, Float64, Vector{Float64}} + values::Vector{Float64} + sum::Float64 + end + MyWeights(values) = MyWeights(values, sum(values)) + + @test mean([1, 2, 3], MyWeights([1, 4, 10])) ≈ 2.6 + @test mean([1, 2, 3], MyWeights([NaN, 4, 10])) |> isnan + @test mode([1, 2, 3], MyWeights([1, 4, 10])) == 3 + @test_throws ArgumentError mode([1, 2, 3], MyWeights([NaN, 4, 10])) +end + end # @testset StatsBase.Weights From 8ca1d97d17003ab02d0bf3307783c54c9d4c02e0 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 16 May 2024 08:26:12 -0400 Subject: [PATCH 6/6] fix new test on 1.0 --- test/weights.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/test/weights.jl b/test/weights.jl index 133fe8a88..76277e02f 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -1,6 +1,15 @@ using StatsBase using LinearAlgebra, Random, SparseArrays, Test + +# minimal custom weights type for tests below +struct MyWeights <: AbstractWeights{Float64, Float64, Vector{Float64}} + values::Vector{Float64} + sum::Float64 +end +MyWeights(values) = MyWeights(values, sum(values)) + + @testset "StatsBase.Weights" begin weight_funcs = (weights, aweights, fweights, pweights) @@ -611,12 +620,6 @@ end end @testset "custom weight types" begin - struct MyWeights <: AbstractWeights{Float64, Float64, Vector{Float64}} - values::Vector{Float64} - sum::Float64 - end - MyWeights(values) = MyWeights(values, sum(values)) - @test mean([1, 2, 3], MyWeights([1, 4, 10])) ≈ 2.6 @test mean([1, 2, 3], MyWeights([NaN, 4, 10])) |> isnan @test mode([1, 2, 3], MyWeights([1, 4, 10])) == 3