From db2a1e6426aca815893fc29cf6ab5fa90e0e772b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 24 Jul 2023 17:57:06 +0100 Subject: [PATCH] fix issue in merging when CSV parses as other type than String --- src/merge.jl | 15 ++++++++++----- src/sieve_variance.jl | 14 ++++++++++++++ src/utils.jl | 4 ++-- test/sieve_variance.jl | 9 ++++++++- test/tmle.jl | 2 +- 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/merge.jl b/src/merge.jl index 405a8c7..3c34649 100644 --- a/src/merge.jl +++ b/src/merge.jl @@ -9,10 +9,15 @@ function files_matching_prefix_and_suffix(prefix, suffix) return [joinpath(dirname_, x) for x in files] end -function load_csv_files(files) - data = DataFrame() +read_output_with_types(file) = + CSV.read(file, DataFrame, types=Dict(key => String for key in joining_keys())) + +function load_csv_files(data, files) for file in files - data = vcat(data, CSV.read(file, DataFrame)) + new_data = read_output_with_types(file) + if size(new_data, 1) > 0 + data = vcat(data, new_data) + end end return data end @@ -25,7 +30,7 @@ function merge_csv_files(parsed_args) ".csv" ) # Load tmle data - data = load_csv_files(tmle_files) + data = load_csv_files(empty_tmle_output(), tmle_files) # Load sieve data sieveprefix = parsed_args["sieve-prefix"] if sieveprefix !== nothing @@ -33,7 +38,7 @@ function merge_csv_files(parsed_args) parsed_args["sieve-prefix"], ".csv" ) - sieve_data = load_csv_files(sieve_files) + sieve_data = load_csv_files(empty_sieve_output(), sieve_files) if size(sieve_data, 1) > 0 data = leftjoin(data, sieve_data, on=joining_keys(), matchmissing=:equal) end diff --git a/src/sieve_variance.jl b/src/sieve_variance.jl index fcef1af..27c4c04 100644 --- a/src/sieve_variance.jl +++ b/src/sieve_variance.jl @@ -34,6 +34,20 @@ sieve_dataframe() = DataFrame( TMLE_ESTIMATE=Float64[], ) +empty_sieve_output() = DataFrame( + PARAMETER_TYPE=String[], + TREATMENTS=String[], + CASE=String[], + CONTROL=Union{String, Missing}[], + TARGET=String[], + CONFOUNDERS=String[], + COVARIATES=Union{String, Missing}[], + SIEVE_STD = Float64[], + SIEVE_PVALUE = Float64[], + SIEVE_LWB = Float64[], + SIEVE_UPB = Float64[], +) + function push_sieveless!(output, Ψ, Ψ̂) target = string(Ψ.target) param_type = param_string(Ψ) diff --git a/src/utils.jl b/src/utils.jl index 556c879..9ca8dea 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ ##################################################################### -csv_headers(;size=0) = DataFrame( +empty_tmle_output(;size=0) = DataFrame( PARAMETER_TYPE=Vector{String}(undef, size), TREATMENTS=Vector{String}(undef, size), CASE=Vector{String}(undef, size), @@ -75,7 +75,7 @@ statistics_from_result(result::MissingTMLEResult) = (missing, missing, missing, missing, missing) function append_csv(filename, tmle_results, logs) - data = csv_headers(size=size(tmle_results, 1)) + data = empty_tmle_output(size=size(tmle_results, 1)) for (i, (result, log)) in enumerate(zip(tmle_results, logs)) Ψ = result.parameter param_type = param_string(Ψ) diff --git a/test/sieve_variance.jl b/test/sieve_variance.jl index 370d2e5..b90153d 100644 --- a/test/sieve_variance.jl +++ b/test/sieve_variance.jl @@ -313,7 +313,7 @@ end @test size(io["variances"]) == (10, 8) close(io) # check csv file - output = CSV.read(string(outprefix, ".csv"), DataFrame) + output = TargetedEstimation.read_output_with_types(string(outprefix, ".csv")) some_expected_cols = DataFrame( PARAMETER_TYPE = ["IATE", "IATE", "ATE", "IATE", "IATE", "ATE", "ATE", "CM"], TREATMENTS = ["T1_&_T2", "T1_&_T2", "T1_&_T2", "T1_&_T2", "T1_&_T2", "T1_&_T2", "T1", "T1"], @@ -329,6 +329,13 @@ end @test output.SIEVE_UPB isa Vector{Float64} @test output.SIEVE_STD isa Vector{Float64} + tmle_output = TargetedEstimation.load_csv_files( + TargetedEstimation.empty_tmle_output(), + ["tmle_output_1.csv", "tmle_output_2.csv"] + ) + + joined = leftjoin(tmle_output, output, on=TargetedEstimation.joining_keys(), matchmissing=:equal) + @test all(joined.SIEVE_PVALUE .> 0 ) # clean rm(string(outprefix, ".csv")) rm(string(outprefix, ".hdf5")) diff --git a/test/tmle.jl b/test/tmle.jl index cfb9fdf..281041b 100644 --- a/test/tmle.jl +++ b/test/tmle.jl @@ -174,7 +174,7 @@ end ## Check CSV file data = CSV.read(parsed_args["csv-out"], DataFrame) - @test names(TargetedEstimation.csv_headers()) == names(data) + @test names(TargetedEstimation.empty_tmle_output()) == names(data) @test size(data) == (6, 19) all(x === missing for x in data.LOG) # Clean