diff --git a/src/utils.jl b/src/utils.jl index 071d8e2..740e006 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -152,14 +152,34 @@ function make_float!(dataset, colnames) end end -function coerce_types!(dataset, colnames) - infered_types = autotype(dataset[!, colnames]) +function coerce_types!(dataset, colnames; rules=:few_to_finite) + infered_types = autotype(dataset[!, colnames], rules) coerce!(dataset, infered_types) end -coerce_types!(dataset, Ψ::TMLE.Estimand) = - coerce_types!(dataset, collect(variables(Ψ))) +""" +Outcomes and Treatment variables must be dealt with differently until there are models dealing specificaly with Count data. + +- Outcomes need to be binary, i.e. OrderedFactor{2} or Continuous +- Treatments need to be Categorical +- Other variables can be dealt with either way +""" +function coerce_types!(dataset, Ψ::TMLE.Estimand) + all_outcomes = outcomes(Ψ) + for outcome in all_outcomes + if isbinary(outcome, dataset) + coerce_types!(dataset, [outcome], rules=:few_to_finite) + else + coerce_types!(dataset, [outcome], rules=:discrete_to_continuous) + end + end + other_variables = collect(setdiff(variables(Ψ), all_outcomes)) + coerce_types!(dataset, other_variables, rules=:few_to_finite) +end + +outcomes(Ψ::TMLE.Estimand) = Set([Ψ.outcome]) +outcomes(Ψ::TMLE.ComposedEstimand) = union((outcomes(arg) for arg in Ψ.args)...) variables(Ψ::TMLE.ComposedEstimand) = union((variables(arg) for arg in Ψ.args)...) diff --git a/test/sieve_variance.jl b/test/sieve_variance.jl index 5bf29a2..ef47781 100644 --- a/test/sieve_variance.jl +++ b/test/sieve_variance.jl @@ -44,6 +44,7 @@ function write_sieve_dataset(sample_ids) dataset[!, "CONTINUOUS, OUTCOME"] = y₁ dataset[!, "BINARY/OUTCOME"] = categorical(y₂) + dataset[!, "COUNT_OUTCOME"] = rand(rng, [1, 2, 3, 4], n) CSV.write("data.csv", dataset) end diff --git a/test/testutils.jl b/test/testutils.jl index db9f450..969d7cd 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -51,7 +51,7 @@ function statistical_estimands_only_config() outcome_extra_covariates = (:C1,) ), CM( - outcome = Symbol("CONTINUOUS, OUTCOME"), + outcome = Symbol("COUNT_OUTCOME"), treatment_values = ( T1 = true, T2 = false), @@ -124,7 +124,7 @@ function build_dataset(;n=1000, format="csv") dataset[!, "CONTINUOUS, OUTCOME"] = y₁ # Slash in name dataset[!, "BINARY/OUTCOME"] = y₂ - dataset[!, "EXTREME_BINARY"] = vcat(0, ones(n-1)) + dataset[!, "COUNT_OUTCOME"] = rand(rng, [1, 2, 3, 4], n) return dataset end diff --git a/test/utils.jl b/test/utils.jl index 9a31afb..517154e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -68,39 +68,48 @@ end ) end @testset "Test coerce_types!" begin - Ψ = IATE( - outcome=:Ycont, - treatment_values=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - treatment_confounders=(T₁=[:W₁, :W₂], T₂=[:W₁, :W₂]), - ) - dataset = DataFrame( Ycont = [1.1, 2.2, missing, 3.5, 6.6, 0., 4.], - Ycat = [1., 0., missing, 1., 0, 0, 0], + Ybin = [1., 0., missing, 1., 0, 0, 0], + Ycount = [1, 0., missing, 1, 2, 0, 3], T₁ = [1, 0, missing, 0, 0, 0, missing], T₂ = [missing, "AC", "CC", "CC", missing, "AA", "AA"], W₁ = [1., 0., 0., 1., 0., 1, 1], W₂ = [missing, 0., 0., 0., 0., 0., 0.], C = [1, 2, 3, 4, 5, 6, 6] ) + # Continuous Outcome + Ψ = IATE( + outcome=:Ycont, + treatment_values=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), + treatment_confounders=(T₁=[:W₁, :W₂], T₂=[:W₁, :W₂]), + ) TargetedEstimation.coerce_types!(dataset, Ψ) - @test scitype(dataset.T₁) == AbstractVector{Union{Missing, OrderedFactor{2}}} @test scitype(dataset.T₂) == AbstractVector{Union{Missing, Multiclass{3}}} @test scitype(dataset.Ycont) == AbstractVector{Union{Missing, MLJBase.Continuous}} @test scitype(dataset.W₁) == AbstractVector{OrderedFactor{2}} @test scitype(dataset.W₂) == AbstractVector{Union{Missing, OrderedFactor{1}}} - + + # Binary Outcome Ψ = IATE( - outcome=:Ycat, + outcome=:Ybin, treatment_values=(T₂=(case="AC", control="CC"), ), treatment_confounders=(T₂=[:W₂],), outcome_extra_covariates=[:C] ) TargetedEstimation.coerce_types!(dataset, Ψ) - - @test scitype(dataset.Ycat) == AbstractVector{Union{Missing, OrderedFactor{2}}} + @test scitype(dataset.Ybin) == AbstractVector{Union{Missing, OrderedFactor{2}}} @test scitype(dataset.C) == AbstractVector{Count} + + # Count Outcome + Ψ = IATE( + outcome=:Ycount, + treatment_values=(T₂=(case="AC", control="CC"), ), + treatment_confounders=(T₂=[:W₂],), + ) + TargetedEstimation.coerce_types!(dataset, Ψ) + @test scitype(dataset.Ycount) == AbstractVector{Union{Missing, MLJBase.Continuous}} end @testset "Test misc" begin