Skip to content

Commit

Permalink
Merge pull request #51 from olivierlabayle/fix_adapt
Browse files Browse the repository at this point in the history
update adapt
  • Loading branch information
olivierlabayle authored Mar 19, 2022
2 parents b375237 + 7094b21 commit 687e543
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ Base.merge(ndt₁::AbstractNode, ndt₂::AbstractNode) =
Adapts the type of the treatment variable passed to the G learner
"""
adapt(T::NamedTuple{<:Any, NTuple{1, Z}}) where Z = T[1]
adapt(T) = T
adapt(T) =
size(Tables.columnnames(T), 1) == 1 ? Tables.getcolumn(T, 1) : T

adapt(T::AbstractNode) = node(adapt, T)


Expand Down
6 changes: 3 additions & 3 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ end
@testset "Test multiple targets with missing data" begin
n = 100
rng = StableRNG(123)
T = (t=categorical(rand(rng, [0, 1], n)),)
W = Tables.table(rand(rng, n, 2))
T = Tables.table(categorical(rand(rng, [0, 1], n)))
W = (w=rand(rng, n),)
Y = (
y₁ = vcat(rand(rng, n-10), repeat([missing], 10)),
y₂ = vcat(repeat([missing], 20), rand(rng, n-20))
)
query = Query((t=0,), (t=1,))
query = Query((Column1=0,), (Column1=1,))
= MLJ.DeterministicConstantRegressor()
G = ConstantClassifier()

Expand Down

2 comments on commit 687e543

@olivierlabayle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.6.1 already exists and points to a different commit"

Please sign in to comment.