Skip to content

Commit

Permalink
relax metrics to Real (#74)
Browse files Browse the repository at this point in the history
* format files

* relax metrics to Real and test it

* fix missing ||
  • Loading branch information
KristofferC authored Aug 1, 2017
1 parent e26201a commit 7db5360
Show file tree
Hide file tree
Showing 9 changed files with 507 additions and 448 deletions.
1 change: 0 additions & 1 deletion src/Distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ export
rogerstanimoto,
chebyshev,
minkowski,
mahalanobis,

hamming,
cosine_dist,
Expand Down
18 changes: 9 additions & 9 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function get_colwise_dims(d::Int, r::AbstractArray, a::AbstractVector, b::Abstra
end

function get_colwise_dims(d::Int, r::AbstractArray, a::AbstractMatrix, b::AbstractVector)
size(a, 1) == length(b) == d
size(a, 1) == length(b) == d ||
throw(DimensionMismatch("Incorrect vector dimensions."))
length(r) == size(a, 2) || throw(DimensionMismatch("Incorrect size of r."))
return size(a)
Expand Down Expand Up @@ -109,10 +109,10 @@ function sumsq_percol(a::AbstractMatrix{T}) where {T}
return r
end

function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1,T2}
function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1, T2}
m = size(a, 1)
n = size(a, 2)
T = typeof(one(T1)*one(T2))
T = typeof(one(T1) * one(T2))
r = Vector{T}(n)
for j = 1:n
aj = view(a, :, j)
Expand All @@ -126,16 +126,16 @@ function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1,T2
end

function dot_percol!(r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
m = size(a,1)
n = size(a,2)
size(b) == (m,n) && length(r) == n ||
m = size(a, 1)
n = size(a, 2)
size(b) == (m, n) && length(r) == n ||
throw(DimensionMismatch("Inconsistent array dimensions."))
for j = 1:n
aj = view(a,:,j)
bj = view(b,:,j)
aj = view(a, :, j)
bj = view(b, :, j)
r[j] = dot(aj, bj)
end
return r
end

dot_percol(a::AbstractMatrix, b::AbstractMatrix) = dot_percol!(Vector{Float64}(size(a,2)), a, b)
dot_percol(a::AbstractMatrix, b::AbstractMatrix) = dot_percol!(Vector{Float64}(size(a, 2)), a, b)
28 changes: 14 additions & 14 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ result_type(::PreMetric, ::AbstractArray, ::AbstractArray) = Float64
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
n = size(b, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
for j = 1 : n
for j = 1:n
@inbounds r[j] = evaluate(metric, a, view(b, :, j))
end
r
Expand All @@ -41,7 +41,7 @@ end
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
n = size(a, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
for j = 1 : n
for j = 1:n
@inbounds r[j] = evaluate(metric, view(a, :, j), b)
end
r
Expand All @@ -50,7 +50,7 @@ end
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
n = get_common_ncols(a, b)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
for j = 1 : n
for j = 1:n
@inbounds r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
end
r
Expand Down Expand Up @@ -85,10 +85,10 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::A
na = size(a, 2)
nb = size(b, 2)
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
for j = 1 : size(b, 2)
bj = view(b,:,j)
for i = 1 : size(a, 2)
@inbounds r[i,j] = evaluate(metric, view(a,:,i), bj)
for j = 1:size(b, 2)
bj = view(b, :, j)
for i = 1:size(a, 2)
@inbounds r[i, j] = evaluate(metric, view(a, :, i), bj)
end
end
r
Expand All @@ -101,14 +101,14 @@ end
function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
n = size(a, 2)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
for j = 1 : n
aj = view(a,:,j)
for i = j+1 : n
@inbounds r[i,j] = evaluate(metric, view(a,:,i), aj)
for j = 1:n
aj = view(a, :, j)
for i = (j + 1):n
@inbounds r[i, j] = evaluate(metric, view(a, :, i), aj)
end
@inbounds r[j,j] = 0
for i = 1 : j-1
@inbounds r[i,j] = r[j,i] # leveraging the symmetry of SemiMetric
@inbounds r[j, j] = 0
for i = 1:(j - 1)
@inbounds r[i, j] = r[j, i] # leveraging the symmetry of SemiMetric
end
end
r
Expand Down
38 changes: 19 additions & 19 deletions src/mahalanobis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ result_type(::SqMahalanobis{T}, ::AbstractArray, ::AbstractArray) where {T} = T

# SqMahalanobis

function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: AbstractFloat}
function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -25,22 +25,22 @@ end

sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(SqMahalanobis(Q), a, b)

function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, n = get_colwise_dims(size(Q, 1), r, a, b)
z = a - b
dot_percol!(r, Q * z, z)
end

function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: AbstractFloat}
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, n = get_colwise_dims(size(Q, 1), r, a, b)
z = a .- b
Qz = Q * z
dot_percol!(r, Q * z, z)
end

function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)

Expand All @@ -50,29 +50,29 @@ function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix,
sb2 = dot_percol(b, Qb)
At_mul_B!(r, a, Qb)

for j = 1 : nb
@simd for i = 1 : na
@inbounds r[i,j] = sa2[i] + sb2[j] - 2 * r[i,j]
for j = 1:nb
@simd for i = 1:na
@inbounds r[i, j] = sa2[i] + sb2[j] - 2 * r[i, j]
end
end
r
end

function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: AbstractFloat}
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, n = get_pairwise_dims(size(Q, 1), r, a)

Qa = Q * a
sa2 = dot_percol(a, Qa)
At_mul_B!(r, a, Qa)

for j = 1 : n
for i = 1 : j-1
@inbounds r[i,j] = r[j,i]
for j = 1:n
for i = 1:(j - 1)
@inbounds r[i, j] = r[j, i]
end
r[j,j] = 0
for i = j+1 : n
@inbounds r[i,j] = sa2[i] + sa2[j] - 2 * r[i,j]
r[j, j] = 0
for i = (j + 1):n
@inbounds r[i, j] = sa2[i] + sa2[j] - 2 * r[i, j]
end
end
r
Expand All @@ -81,24 +81,24 @@ end

# Mahalanobis

function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: AbstractFloat}
function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
sqrt(evaluate(SqMahalanobis(dist.qmat), a, b))
end

mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(Mahalanobis(Q), a, b)

function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: AbstractFloat}
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: Real}
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: AbstractFloat}
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: Real}
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a))
end
Loading

0 comments on commit 7db5360

Please sign in to comment.