From f1e54a2ab7ce1ba4602b2220479f95239ac49eaa Mon Sep 17 00:00:00 2001 From: MariaHei Date: Mon, 10 Jun 2024 09:44:23 +0200 Subject: [PATCH] Add token-based accuracy measure Fixes #118 --- src/cholesky.jl | 4 ++-- src/eval.jl | 45 ++++++++++++++++++++++++++++++------- test/eval_tests.jl | 55 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 10 deletions(-) diff --git a/src/cholesky.jl b/src/cholesky.jl index 47b2a6fb..ec8e0564 100644 --- a/src/cholesky.jl +++ b/src/cholesky.jl @@ -327,7 +327,7 @@ end """ make_transform_matrix(X::Union{SparseMatrixCSC,Matrix}, Y::Union{SparseMatrixCSC,Matrix}, - freq::Array{Int64,1}) + freq::Union{Array{Int64, 1}, Array{Float64,1}}) Weight X and Y using the frequencies in freq. Then use the Cholesky decomposition to calculate the transformation matrix from X to Y, @@ -336,7 +336,7 @@ where X is a sparse matrix and Y is a sparse matrix. # Obligatory Arguments - `X::SparseMatrixCSC`: the X matrix, where X is a sparse matrix - `Y::SparseMatrixCSC`: the Y matrix, where Y is a sparse matrix -- `freq::Array{Int64,1}`: list of frequencies of the wordforms in X and Y +- `freq::Union{Array{Int64, 1}, Array{Float64,1}}`: list of frequencies of the wordforms in X and Y # Optional Arguments - `method::Symbol = :additive`: whether :additive or :multiplicative decomposition is required diff --git a/src/eval.jl b/src/eval.jl index 43922661..b01cf7c7 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -225,6 +225,7 @@ function accuracy_comprehension( Comp_Acc_Struct(dfr, acc, err) end + """ eval_SC(SChat::AbstractArray, SC::AbstractArray) @@ -232,6 +233,8 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an C or Shat and S. Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. +If `freq` is added, token-based accuracy is computed. Token-based accuracy weighs accuracy values according to words' frequency, i.e. if a word has a frequency of 30 and overall there are 3000 tokens (the frequencies of all types sum to 3000), this token's accuracy will contribute 30/3000. + !!! note If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs. @@ -242,6 +245,7 @@ of the pertinent correlation matrices. # Optional Arguments - `digits`: the specified number of digits after the decimal place (or before if negative) - `R::Bool=false`: if true, pairwise correlation matrix R is return +- `freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing`: list of frequencies of the wordforms in X and Y ```julia eval_SC(Chat_train, cue_obj_train.C) @@ -250,7 +254,8 @@ eval_SC(Shat_train, S_train) eval_SC(Shat_val, S_val) ``` """ -function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false) +function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false, + freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing) if size(unique(SC, dims=1), 1) != size(SC, 1) @warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information." @@ -262,7 +267,12 @@ function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false) dims = 2, ) v = [rSC[i[1], i[1]] == rSC[i] ? 1 : 0 for i in argmax(rSC, dims = 2)] - acc = round(sum(v) / length(v), digits=digits) + if !ismissing(freq) + v .*= freq + acc = round(sum(v) / sum(freq), digits=digits) + else + acc = round(sum(v) / length(v), digits=digits) + end if R return acc, rSC else @@ -277,6 +287,8 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an C or Shat and S. Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. +If `freq` is added, token-based accuracy is computed. Token-based accuracy weighs accuracy values according to words' frequency, i.e. if a word has a frequency of 30 and overall there are 3000 tokens (the frequencies of all types sum to 3000), this token's accuracy will contribute 30/3000. + !!! note The order is important. The fist gold standard matrix has to be corresponing to the SChat matrix, such as `eval_SC(Shat_train, S_train, S_val)` or `eval_SC(Shat_val, S_val, S_train)` @@ -292,6 +304,7 @@ of the pertinent correlation matrices. # Optional Arguments - `digits`: the specified number of digits after the decimal place (or before if negative) - `R::Bool=false`: if true, pairwise correlation matrix R is return +- `freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing`: list of frequencies of the wordforms in X and Y ```julia eval_SC(Chat_train, cue_obj_train.C, cue_obj_val.C) @@ -305,12 +318,14 @@ function eval_SC( SC::AbstractArray, SC_rest::AbstractArray; digits = 4, - R = false + R = false, + freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing ) - eval_SC(SChat, vcat(SC, SC_rest); digits=digits, R=R) + eval_SC(SChat, vcat(SC, SC_rest); digits=digits, R=R, freq=freq) end + """ eval_SC(SChat::AbstractArray, SC::AbstractArray, data::DataFrame, target_col::Union{String, Symbol}) @@ -318,6 +333,8 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an C or Shat and S. Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. Support for homophones. +If `freq` is added, token-based accuracy is computed. Token-based accuracy weighs accuracy values according to words' frequency, i.e. if a word has a frequency of 30 and overall there are 3000 tokens (the frequencies of all types sum to 3000), this token's accuracy will contribute 30/3000. + # Obligatory Arguments - `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix - `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix @@ -327,6 +344,7 @@ of the pertinent correlation matrices. Support for homophones. # Optional Arguments - `digits`: the specified number of digits after the decimal place (or before if negative) - `R::Bool=false`: if true, pairwise correlation matrix R is return +- `freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing`: list of frequencies of the wordforms in X and Y ```julia eval_SC(Chat_train, cue_obj_train.C, latin, :Word) @@ -341,7 +359,8 @@ function eval_SC( data::DataFrame, target_col::Union{String, Symbol}; digits = 4, - R = false + R = false, + freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing ) rSC = cor( @@ -353,7 +372,12 @@ function eval_SC( data[i[1], target_col] == data[i[2], target_col] ? 1 : 0 for i in argmax(rSC, dims = 2) ] - acc = round(sum(v) / length(v), digits=digits) + if !ismissing(freq) + v .*= freq + acc = round(sum(v) / sum(freq), digits=digits) + else + acc = round(sum(v) / length(v), digits=digits) + end if R return acc, rSC else @@ -368,6 +392,8 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an C or Shat and S. Ideally the target words have highest correlations on the diagonal of the pertinent correlation matrices. +If `freq` is added, token-based accuracy is computed. Token-based accuracy weighs accuracy values according to words' frequency, i.e. if a word has a frequency of 30 and overall there are 3000 tokens (the frequencies of all types sum to 3000), this token's accuracy will contribute 30/3000. + !!! note The order is important. The fist gold standard matrix has to be corresponing to the SChat matrix, such as `eval_SC(Shat_train, S_train, S_val, latin, :Word)` @@ -384,6 +410,7 @@ of the pertinent correlation matrices. # Optional Arguments - `digits`: the specified number of digits after the decimal place (or before if negative) - `R::Bool=false`: if true, pairwise correlation matrix R is return +- `freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing`: list of frequencies of the wordforms in X and Y ```julia eval_SC(Chat_train, cue_obj_train.C, cue_obj_val.C, latin, :Word) @@ -400,7 +427,8 @@ function eval_SC( data_rest::DataFrame, target_col::Union{String, Symbol}; digits = 4, - R = false + R = false, + freq::Union{Missing, Array{Int64, 1}, Array{Float64,1}}=missing ) n_data = size(data, 1) @@ -428,7 +456,8 @@ function eval_SC( data_combined, target_col, digits = digits, - R = R + R = R, + freq=freq ) end diff --git a/test/eval_tests.jl b/test/eval_tests.jl index 2538b3c5..934c847b 100644 --- a/test/eval_tests.jl +++ b/test/eval_tests.jl @@ -171,6 +171,61 @@ end end +@testset "token_accuracy" begin + S = [[1. 2. 3.] + [4. 5. 6.] + [7. 8. 9.]] + + S2 = [[1. 2. 3.] + [0. 0.1 0.1] + [9. 8. 7.]] + + @test JudiLing.eval_SC(S, S2) ≈ 0.3333 + @test JudiLing.eval_SC(S, S2, freq=[1,1,1]) ≈ 0.3333 + @test JudiLing.eval_SC(S, S2, freq=[1,2,3]) ≈ 0.1667 + @test JudiLing.eval_SC(S, S2, freq=[3,2,1]) ≈ 0.5 + @test JudiLing.eval_SC(S, S, freq=[1,2,3]) == JudiLing.eval_SC(S, S) + + S3 = [7. 8. 9.] + + @test JudiLing.eval_SC(S, S2, S3) ≈ 0.3333 + @test JudiLing.eval_SC(S, S2, S3, freq=[1,1,1]) ≈ 0.3333 + @test JudiLing.eval_SC(S, S2, S3, freq=[1,2,3]) ≈ 0.1667 + @test JudiLing.eval_SC(S, S2, S3, freq=[3,2,1]) ≈ 0.5 + @test JudiLing.eval_SC(S2, S2, S3, freq=[1,2,3]) == JudiLing.eval_SC(S2, S2) + + S4 = [[1. 2. 3.] + [6. 5. 4.] + [7. 8. 9.]] + + S5 = [[1. 2. 3.] + [9. 3. 1.] + [0. 0.1 0.1]] + + S6 = [6. 5. 4.] + + @test JudiLing.eval_SC(S4, S5) ≈ 0.6667 + @test JudiLing.eval_SC(S4, S5, S6) ≈ 0.3333 + @test JudiLing.eval_SC(S4, S5, freq=[1,1,1]) ≈ 0.6667 + @test JudiLing.eval_SC(S4, S5, S6, freq=[1,1,1]) ≈ 0.3333 + @test JudiLing.eval_SC(S4, S5, S6, freq=[1,2,3]) ≈ 0.1667 + @test JudiLing.eval_SC(S4, S5, S6, freq=[3,2,1]) ≈ 0.5 + + data = DataFrame("Word"=>["a", "b", "a"]) + @test JudiLing.eval_SC(S, S2, data, "Word") ≈ 0.6667 + @test JudiLing.eval_SC(S, S2, data, "Word", freq=[1,1,1]) ≈ 0.6667 + @test JudiLing.eval_SC(S, S2, data, "Word", freq=[1,3,2]) ≈ 0.5 + @test JudiLing.eval_SC(S, S2, data, "Word", freq=[3,1,2]) ≈ round(5/6, digits=4) + @test JudiLing.eval_SC(S2, S2, data, "Word", freq=[3,1,2]) ≈ 1.0 + + data = DataFrame("Word"=>["a", "b", "c"]) + data2 = DataFrame("Word"=>["b"]) + @test JudiLing.eval_SC(S4, S5, S6, data, data2, "Word") ≈ 0.6667 + @test JudiLing.eval_SC(S4, S5, S6, data, data2, "Word", freq=[1,1,1]) ≈ 0.6667 + @test JudiLing.eval_SC(S4, S5, S6, data, data2, "Word", freq=[3,2,1]) ≈ round(5/6, digits=4) + @test JudiLing.eval_SC(S4, S5, S6, data, data2, "Word", freq=[1,2,3]) ≈ 0.5 +end + @testset "accuracy_comprehension" begin latin = DataFrame( Word = ["ABC", "BCD", "CDE", "BCD"],