Skip to content

Commit

Permalink
Merge pull request #66 from olivierlabayle/machine_cache
Browse files Browse the repository at this point in the history
add no machine cache as a default
  • Loading branch information
olivierlabayle authored Dec 20, 2022
2 parents 923dc15 + 640166f commit 935bdc2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ Main entrypoint to run the TMLE procedure.
- verbosity: The logging level
- threshold: To avoid small values of Ĝ to cause the "clever covariate" to explode
"""
function tmle::Parameter, η_spec::NuisanceSpec, dataset; verbosity=1, threshold=1e-8)
function tmle::Parameter, η_spec::NuisanceSpec, dataset; verbosity=1, threshold=1e-8, mach_cache=false)
cache = TMLECache(Ψ, η_spec, dataset)
return tmle!(cache; verbosity=verbosity, threshold=threshold)
return tmle!(cache; verbosity=verbosity, threshold=threshold, mach_cache=mach_cache)
end

function tmle!(cache; verbosity=1, threshold=1e-8)
function tmle!(cache; verbosity=1, threshold=1e-8, mach_cache=false)
Ψ, η_spec, dataset, η = cache.Ψ, cache.η_spec, cache.dataset, cache.η
# Initial fit
verbosity >= 1 && @info "Fitting the nuisance parameters..."
TMLE.fit!(η, η_spec, Ψ, dataset, verbosity=verbosity)
TMLE.fit!(η, η_spec, Ψ, dataset, verbosity=verbosity, mach_cache=mach_cache)

# Estimation results before TMLE
dataset = dataset[:no_missing]
Expand All @@ -98,7 +98,7 @@ function tmle!(cache; verbosity=1, threshold=1e-8)

# TMLE step
verbosity >= 1 && @info "Targeting the nuisance parameters..."
tmle!(η, Ψ, η_spec, dataset, verbosity=verbosity, threshold=threshold)
tmle!(η, Ψ, η_spec, dataset, verbosity=verbosity, threshold=threshold, mach_cache=mach_cache)

# Estimation results after TMLE
IC = gradient(Ψ, η, dataset; threshold=threshold)
Expand Down
12 changes: 6 additions & 6 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,15 @@ Q_model(t::Type{Any}) = throw(ArgumentError("Cannot proceed with Q model with ta
Fits the nuisance parameters η on the dataset using the specifications from η_spec
and the variables defined by Ψ.
"""
function fit!::NuisanceParameters, η_spec::NuisanceSpec, Ψ::Parameter, dataset; verbosity=1)
function fit!::NuisanceParameters, η_spec::NuisanceSpec, Ψ::Parameter, dataset; verbosity=1, mach_cache=false)
# Fitting P(T|W)
# Only rows with missing values in either W or Tₜ are removed
if η.G === nothing
log_fit(verbosity, "P(T|W)")
nomissing_WT = nomissing(dataset[:source], treatment_and_confounders(Ψ))
W = confounders(nomissing_WT, Ψ)
T = treatments(nomissing_WT, Ψ)
mach = machine(adapt(η_spec.G, T), W, adapt(T))
mach = machine(adapt(η_spec.G, T), W, adapt(T), cache=mach_cache)
MLJBase.fit!(mach, verbosity=verbosity-1)
η.G = mach
else
Expand All @@ -248,14 +248,14 @@ function fit!(η::NuisanceParameters, η_spec::NuisanceSpec, Ψ::Parameter, data
# Fitting the Encoder
if η.H === nothing
log_fit(verbosity, "Encoder")
mach = machine(η_spec.H, X)
mach = machine(η_spec.H, X, cache=mach_cache)
MLJBase.fit!(mach, verbosity=verbosity-1)
η.H = mach
else
log_no_fit(verbosity, "Encoder")
end
log_fit(verbosity, "E[Y|X]")
mach = machine(η_spec.Q, MLJBase.transform.H, X), y)
mach = machine(η_spec.Q, MLJBase.transform.H, X), y, cache=mach_cache)
MLJBase.fit!(mach, verbosity=verbosity-1)
η.Q = mach
else
Expand All @@ -276,10 +276,10 @@ function fluctuation_input(dataset, η, Ψ; threshold=1e-8)
return fluctuation_input(covariate, offset)
end

function tmle!::NuisanceParameters, Ψ, η_spec, dataset; verbosity=1, threshold=1e-8)
function tmle!::NuisanceParameters, Ψ, η_spec, dataset; verbosity=1, threshold=1e-8, mach_cache=false)
X = fluctuation_input(dataset, η, Ψ, threshold=threshold)
y = target(dataset, Ψ)
mach = machine(η_spec.F, X, y)
mach = machine(η_spec.F, X, y, cache=mach_cache)
MLJBase.fit!(mach, verbosity=verbosity-1)
η.F = mach
end
Expand Down

0 comments on commit 935bdc2

Please sign in to comment.