Skip to content

Commit

Permalink
Merge pull request #27 from TARGENE/fix_coerce
Browse files Browse the repository at this point in the history
fix issue with count type outcome
  • Loading branch information
olivierlabayle authored May 3, 2024
2 parents 4b78fab + 625e099 commit fdfa3a6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 18 deletions.
28 changes: 24 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,34 @@ function make_float!(dataset, colnames)
end
end

function coerce_types!(dataset, colnames)
infered_types = autotype(dataset[!, colnames])
function coerce_types!(dataset, colnames; rules=:few_to_finite)
infered_types = autotype(dataset[!, colnames], rules)
coerce!(dataset, infered_types)
end

coerce_types!(dataset, Ψ::TMLE.Estimand) =
coerce_types!(dataset, collect(variables(Ψ)))
"""
Outcomes and Treatment variables must be dealt with differently until there are models dealing specificaly with Count data.
- Outcomes need to be binary, i.e. OrderedFactor{2} or Continuous
- Treatments need to be Categorical
- Other variables can be dealt with either way
"""
function coerce_types!(dataset, Ψ::TMLE.Estimand)
all_outcomes = outcomes(Ψ)
for outcome in all_outcomes
if isbinary(outcome, dataset)
coerce_types!(dataset, [outcome], rules=:few_to_finite)
else
coerce_types!(dataset, [outcome], rules=:discrete_to_continuous)
end
end
other_variables = collect(setdiff(variables(Ψ), all_outcomes))
coerce_types!(dataset, other_variables, rules=:few_to_finite)
end

outcomes::TMLE.Estimand) = Set([Ψ.outcome])

outcomes::TMLE.ComposedEstimand) = union((outcomes(arg) for arg in Ψ.args)...)

variables::TMLE.ComposedEstimand) = union((variables(arg) for arg in Ψ.args)...)

Expand Down
1 change: 1 addition & 0 deletions test/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ function write_sieve_dataset(sample_ids)

dataset[!, "CONTINUOUS, OUTCOME"] = y₁
dataset[!, "BINARY/OUTCOME"] = categorical(y₂)
dataset[!, "COUNT_OUTCOME"] = rand(rng, [1, 2, 3, 4], n)

CSV.write("data.csv", dataset)
end
Expand Down
4 changes: 2 additions & 2 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function statistical_estimands_only_config()
outcome_extra_covariates = (:C1,)
),
CM(
outcome = Symbol("CONTINUOUS, OUTCOME"),
outcome = Symbol("COUNT_OUTCOME"),
treatment_values = (
T1 = true,
T2 = false),
Expand Down Expand Up @@ -124,7 +124,7 @@ function build_dataset(;n=1000, format="csv")
dataset[!, "CONTINUOUS, OUTCOME"] = y₁
# Slash in name
dataset[!, "BINARY/OUTCOME"] = y₂
dataset[!, "EXTREME_BINARY"] = vcat(0, ones(n-1))
dataset[!, "COUNT_OUTCOME"] = rand(rng, [1, 2, 3, 4], n)

return dataset
end
Expand Down
33 changes: 21 additions & 12 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,39 +68,48 @@ end
)
end
@testset "Test coerce_types!" begin
Ψ = IATE(
outcome=:Ycont,
treatment_values=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")),
treatment_confounders=(T₁=[:W₁, :W₂], T₂=[:W₁, :W₂]),
)

dataset = DataFrame(
Ycont = [1.1, 2.2, missing, 3.5, 6.6, 0., 4.],
Ycat = [1., 0., missing, 1., 0, 0, 0],
Ybin = [1., 0., missing, 1., 0, 0, 0],
Ycount = [1, 0., missing, 1, 2, 0, 3],
T₁ = [1, 0, missing, 0, 0, 0, missing],
T₂ = [missing, "AC", "CC", "CC", missing, "AA", "AA"],
W₁ = [1., 0., 0., 1., 0., 1, 1],
W₂ = [missing, 0., 0., 0., 0., 0., 0.],
C = [1, 2, 3, 4, 5, 6, 6]
)
# Continuous Outcome
Ψ = IATE(
outcome=:Ycont,
treatment_values=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")),
treatment_confounders=(T₁=[:W₁, :W₂], T₂=[:W₁, :W₂]),
)
TargetedEstimation.coerce_types!(dataset, Ψ)

@test scitype(dataset.T₁) == AbstractVector{Union{Missing, OrderedFactor{2}}}
@test scitype(dataset.T₂) == AbstractVector{Union{Missing, Multiclass{3}}}
@test scitype(dataset.Ycont) == AbstractVector{Union{Missing, MLJBase.Continuous}}
@test scitype(dataset.W₁) == AbstractVector{OrderedFactor{2}}
@test scitype(dataset.W₂) == AbstractVector{Union{Missing, OrderedFactor{1}}}


# Binary Outcome
Ψ = IATE(
outcome=:Ycat,
outcome=:Ybin,
treatment_values=(T₂=(case="AC", control="CC"), ),
treatment_confounders=(T₂=[:W₂],),
outcome_extra_covariates=[:C]
)
TargetedEstimation.coerce_types!(dataset, Ψ)

@test scitype(dataset.Ycat) == AbstractVector{Union{Missing, OrderedFactor{2}}}
@test scitype(dataset.Ybin) == AbstractVector{Union{Missing, OrderedFactor{2}}}
@test scitype(dataset.C) == AbstractVector{Count}

# Count Outcome
Ψ = IATE(
outcome=:Ycount,
treatment_values=(T₂=(case="AC", control="CC"), ),
treatment_confounders=(T₂=[:W₂],),
)
TargetedEstimation.coerce_types!(dataset, Ψ)
@test scitype(dataset.Ycount) == AbstractVector{Union{Missing, MLJBase.Continuous}}
end

@testset "Test misc" begin
Expand Down

0 comments on commit fdfa3a6

Please sign in to comment.