Skip to content

Commit

Permalink
fix issue in merging when CSV parses as other type than String
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jul 24, 2023
1 parent b306a48 commit db2a1e6
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
15 changes: 10 additions & 5 deletions src/merge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@ function files_matching_prefix_and_suffix(prefix, suffix)
return [joinpath(dirname_, x) for x in files]
end

function load_csv_files(files)
data = DataFrame()
read_output_with_types(file) =
CSV.read(file, DataFrame, types=Dict(key => String for key in joining_keys()))

function load_csv_files(data, files)
for file in files
data = vcat(data, CSV.read(file, DataFrame))
new_data = read_output_with_types(file)
if size(new_data, 1) > 0
data = vcat(data, new_data)
end
end
return data
end
Expand All @@ -25,15 +30,15 @@ function merge_csv_files(parsed_args)
".csv"
)
# Load tmle data
data = load_csv_files(tmle_files)
data = load_csv_files(empty_tmle_output(), tmle_files)
# Load sieve data
sieveprefix = parsed_args["sieve-prefix"]
if sieveprefix !== nothing
sieve_files = files_matching_prefix_and_suffix(
parsed_args["sieve-prefix"],
".csv"
)
sieve_data = load_csv_files(sieve_files)
sieve_data = load_csv_files(empty_sieve_output(), sieve_files)
if size(sieve_data, 1) > 0
data = leftjoin(data, sieve_data, on=joining_keys(), matchmissing=:equal)
end
Expand Down
14 changes: 14 additions & 0 deletions src/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ sieve_dataframe() = DataFrame(
TMLE_ESTIMATE=Float64[],
)

empty_sieve_output() = DataFrame(
PARAMETER_TYPE=String[],
TREATMENTS=String[],
CASE=String[],
CONTROL=Union{String, Missing}[],
TARGET=String[],
CONFOUNDERS=String[],
COVARIATES=Union{String, Missing}[],
SIEVE_STD = Float64[],
SIEVE_PVALUE = Float64[],
SIEVE_LWB = Float64[],
SIEVE_UPB = Float64[],
)

function push_sieveless!(output, Ψ, Ψ̂)
target = string.target)
param_type = param_string(Ψ)
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#####################################################################


csv_headers(;size=0) = DataFrame(
empty_tmle_output(;size=0) = DataFrame(
PARAMETER_TYPE=Vector{String}(undef, size),
TREATMENTS=Vector{String}(undef, size),
CASE=Vector{String}(undef, size),
Expand Down Expand Up @@ -75,7 +75,7 @@ statistics_from_result(result::MissingTMLEResult) =
(missing, missing, missing, missing, missing)

function append_csv(filename, tmle_results, logs)
data = csv_headers(size=size(tmle_results, 1))
data = empty_tmle_output(size=size(tmle_results, 1))
for (i, (result, log)) in enumerate(zip(tmle_results, logs))
Ψ = result.parameter
param_type = param_string(Ψ)
Expand Down
9 changes: 8 additions & 1 deletion test/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ end
@test size(io["variances"]) == (10, 8)
close(io)
# check csv file
output = CSV.read(string(outprefix, ".csv"), DataFrame)
output = TargetedEstimation.read_output_with_types(string(outprefix, ".csv"))
some_expected_cols = DataFrame(
PARAMETER_TYPE = ["IATE", "IATE", "ATE", "IATE", "IATE", "ATE", "ATE", "CM"],
TREATMENTS = ["T1_&_T2", "T1_&_T2", "T1_&_T2", "T1_&_T2", "T1_&_T2", "T1_&_T2", "T1", "T1"],
Expand All @@ -329,6 +329,13 @@ end
@test output.SIEVE_UPB isa Vector{Float64}
@test output.SIEVE_STD isa Vector{Float64}

tmle_output = TargetedEstimation.load_csv_files(
TargetedEstimation.empty_tmle_output(),
["tmle_output_1.csv", "tmle_output_2.csv"]
)

joined = leftjoin(tmle_output, output, on=TargetedEstimation.joining_keys(), matchmissing=:equal)
@test all(joined.SIEVE_PVALUE .> 0 )
# clean
rm(string(outprefix, ".csv"))
rm(string(outprefix, ".hdf5"))
Expand Down
2 changes: 1 addition & 1 deletion test/tmle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ end

## Check CSV file
data = CSV.read(parsed_args["csv-out"], DataFrame)
@test names(TargetedEstimation.csv_headers()) == names(data)
@test names(TargetedEstimation.empty_tmle_output()) == names(data)
@test size(data) == (6, 19)
all(x === missing for x in data.LOG)
# Clean
Expand Down

0 comments on commit db2a1e6

Please sign in to comment.