compute cohen's kappa from confusion matrix #111

Open
ericphanson opened this issue Apr 30, 2024 · 0 comments

Currently, we compute many metrics directly from the confusion matrix, but not Cohen's kappa, which is computed from the label pairs instead:

"""
    cohens_kappa(class_count, hard_label_pairs)

Return `(κ, p₀)` where `κ` is Cohen's kappa and `p₀` percent agreement given
`class_count` and `hard_label_pairs` (these arguments take the same form as
their equivalents in [`confusion_matrix`](@ref)).
"""
function cohens_kappa(class_count, hard_label_pairs)
    all(issubset(pair, 1:class_count) for pair in hard_label_pairs) ||
        throw(ArgumentError("Unexpected class in `hard_label_pairs`."))
    p₀ = accuracy(confusion_matrix(class_count, hard_label_pairs))
    pₑ = _probability_of_chance_agreement(class_count, hard_label_pairs)
    return _cohens_kappa(p₀, pₑ), p₀
end

_cohens_kappa(p₀, pₑ) = (p₀ - pₑ) / (1 - ifelse(pₑ == 1, zero(pₑ), pₑ))

function _probability_of_chance_agreement(class_count, hard_label_pairs)
    labels_1 = (pair[1] for pair in hard_label_pairs)
    labels_2 = (pair[2] for pair in hard_label_pairs)
    x = sum(k -> count(==(k), labels_1) * count(==(k), labels_2), 1:class_count)
    return x / length(hard_label_pairs)^2
end

Computing kappa from the confusion matrix as well could clean up some of the data flow. Here is an implementation:

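using LinearAlgebra: dot  # needed for the chance-agreement computation

function cohens_kappa_from_confusion_matrix(conf)
    p₀ = accuracy(conf)
    pₑ = probability_of_chance_agreement_from_confusion_matrix(conf)
    # Same guard as `_cohens_kappa`: avoid dividing by zero when pₑ == 1.
    return (p₀ - pₑ) / (1 - ifelse(pₑ == 1, zero(pₑ), pₑ))
end

function probability_of_chance_agreement_from_confusion_matrix(conf)
    # The marginals of the confusion matrix are the per-class counts
    # for each of the two label sources.
    counts_1 = dropdims(sum(conf; dims=1); dims=1)
    counts_2 = dropdims(sum(conf; dims=2); dims=2)
    n = sum(counts_1)
    # Sanity check: both marginals must sum to the total number of pairs.
    # (`@check` is assumed to be available, e.g. from ArgCheck.jl.)
    @check n == sum(counts_2)
    return dot(counts_1, counts_2) / n^2
end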
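
As a quick sanity check that the two routes agree, here is a minimal, self-contained sketch. It does not use the package itself: `toy_confusion_matrix` and `toy_accuracy` are hypothetical stand-ins for `confusion_matrix` and `accuracy`, and I am assuming that `conf[i, j]` counts pairs `(i, j)` and that accuracy is the trace divided by the total. The label pairs are made-up toy data.

```julia
using LinearAlgebra: dot, tr

# Hypothetical stand-ins for the package's `confusion_matrix` and `accuracy`
# (assumption: conf[i, j] counts the pairs (i, j)).
function toy_confusion_matrix(class_count, hard_label_pairs)
    conf = zeros(Int, class_count, class_count)
    for (a, b) in hard_label_pairs
        conf[a, b] += 1
    end
    return conf
end
toy_accuracy(conf) = tr(conf) / sum(conf)

# Pair-based chance agreement, mirroring the current implementation.
function chance_agreement_from_pairs(class_count, hard_label_pairs)
    labels_1 = [pair[1] for pair in hard_label_pairs]
    labels_2 = [pair[2] for pair in hard_label_pairs]
    x = sum(k -> count(==(k), labels_1) * count(==(k), labels_2), 1:class_count)
    return x / length(hard_label_pairs)^2
end

# Matrix-based chance agreement, mirroring the proposal above.
function chance_agreement_from_matrix(conf)
    counts_1 = vec(sum(conf; dims=2))  # row sums: counts of the first label
    counts_2 = vec(sum(conf; dims=1))  # column sums: counts of the second label
    return dot(counts_1, counts_2) / sum(conf)^2
end

kappa(p₀, pₑ) = (p₀ - pₑ) / (1 - ifelse(pₑ == 1, zero(pₑ), pₑ))

# Toy data: 6 label pairs over 3 classes.
label_pairs = [(1, 1), (1, 2), (2, 2), (2, 2), (3, 1), (3, 3)]
conf = toy_confusion_matrix(3, label_pairs)
p₀ = toy_accuracy(conf)
κ_pairs = kappa(p₀, chance_agreement_from_pairs(3, label_pairs))
κ_matrix = kappa(p₀, chance_agreement_from_matrix(conf))
@assert κ_pairs ≈ κ_matrix
```

For this toy data, p₀ = 4/6 and pₑ = 12/36, so both routes give κ ≈ 0.5. Since pₑ is a dot product of the two marginals, it does not matter which axis of the confusion matrix corresponds to which label source.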