Skip to content

Commit

Permalink
Merge pull request #48 from olivierlabayle/input_management
Browse files Browse the repository at this point in the history
Input management
  • Loading branch information
olivierlabayle authored Mar 16, 2022
2 parents df498ce + 87677f4 commit e524036
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 13 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

Expand All @@ -18,6 +19,7 @@ HypothesisTests = "0.10"
MLJBase = "0.19"
MLJGLMInterface = "0.1, 0.2"
MLJModels = "0.15"
TableOperations = "1.2"
Tables = "1.5"
julia = "1.6, 1.7"
TableOperations = "1.2"
Missings = "1.0"
2 changes: 1 addition & 1 deletion src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using HypothesisTests
using Base: Iterators, ImmutableDict
using MLJGLMInterface
using MLJModels

using Missings

# #############################################################################
# OVERLOADED METHODS
Expand Down
7 changes: 2 additions & 5 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,11 @@ function MLJBase.fit(tmle::TMLEstimator,
Ws = source(W)
Ys = source(Y)

# Filtering missing values before G fit
T, W = TableOperations.dropmissing(Ts, Ws)

# Fitting the encoder
Hmach = machine(OneHotEncoder(drop_last=true), T)
Hmach = machine(OneHotEncoder(drop_last=true), Ts)

# Fitting P(T|W)
Gmach = machine(tmle.G, W, adapt(T))
Gmach = machine(tmle.G, Ws, adapt(Ts))

reported = []
predicted = []
Expand Down
20 changes: 14 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ totable(x::AbstractVector) = (y=x,)
totable(x) = x

function merge_and_dropmissing(tables::Vararg)
return mapreduce(t->Tables.columntable(t), merge, tables) |>
TableOperations.dropmissing |> Tables.columntable

return mapreduce(t->Tables.columntable(t), merge, tables) |>
TableOperations.dropmissing |>
Tables.columntable |>
disallowmissings
end

TableOperations.select(t::AbstractNode, columns...) =
Expand All @@ -51,14 +52,21 @@ TableOperations.select(t::AbstractNode, columns::AbstractNode) =
Tables.columnnames(t::AbstractNode) =
node(Tables.columnnames, t)

function disallowmissings(T)
newcols = AbstractVector[]
sch = Tables.schema(T)
Tables.eachcolumn(sch, T) do col, _, _
push!(newcols, disallowmissing(col))
end
return NamedTuple{sch.names}(newcols)
end

function TableOperations.dropmissing(tables::Vararg{AbstractNode})
table = node(merge_and_dropmissing, tables...)
table = node(disallowmissings, table)
return Tuple(TableOperations.select(table, Tables.columnnames(t)) for t in tables)
end


Base.first(y::AbstractNode) = node(Base.first, y)

###############################################################################
## Offset
###############################################################################
Expand Down
3 changes: 3 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ end
Column2 = [4, 4],
Column3 = [3, 9]
)
@test eltype(T.t₁) == eltype(T.t₂) == eltype(T.Column3) == Int

T = TMLE.disallowmissings(T)

T₁ = source(T₁)
T₂ = source(T₂)
Expand Down

0 comments on commit e524036

Please sign in to comment.