Skip to content

Commit

Permalink
Merge pull request #67 from olivierlabayle/reduce_alloc
Browse files Browse the repository at this point in the history
Fix cache
  • Loading branch information
olivierlabayle authored Jan 3, 2023
2 parents ff38797 + 9fd9188 commit c286d65
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/src/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,6 @@ nothing # hide

Since we have only updated $G$'s specification, only this model is fitted again.

### Scenario N
### General behaviour

Feel free to play around with the cache and to report any non consistent behaviour.
Any change to either the `Parameter` of interest or the `NuisanceSpec` structures will trigger an update of the cache.
29 changes: 21 additions & 8 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ mutable struct TMLECache
Ψ::Parameter
η_spec::NuisanceSpec
dataset
η
function TMLECache(Ψ, η_spec, dataset)
η::NuisanceParameters
mach_cache::Bool
function TMLECache(Ψ, η_spec, dataset, mach_cache)
dataset = Dict(
:source => dataset,
:no_missing => nomissing(dataset, allcolumns(Ψ))
)
η = NuisanceParameters(nothing, nothing, nothing, nothing)
new(Ψ, η_spec, dataset, η)
new(Ψ, η_spec, dataset, η, mach_cache)
end
end

Expand Down Expand Up @@ -78,14 +79,15 @@ Main entrypoint to run the TMLE procedure.
- dataset: A tabular dataset respecting the Table.jl interface
- verbosity: The logging level
- threshold: To avoid small values of Ĝ to cause the "clever covariate" to explode
- mach_cache: Whether underlying MLJ.machines will cache data or not
"""
function tmle::Parameter, η_spec::NuisanceSpec, dataset; verbosity=1, threshold=1e-8, mach_cache=false)
cache = TMLECache(Ψ, η_spec, dataset)
return tmle!(cache; verbosity=verbosity, threshold=threshold, mach_cache=mach_cache)
cache = TMLECache(Ψ, η_spec, dataset, mach_cache)
return tmle!(cache; verbosity=verbosity, threshold=threshold)
end

function tmle!(cache; verbosity=1, threshold=1e-8, mach_cache=false)
Ψ, η_spec, dataset, η = cache.Ψ, cache.η_spec, cache.dataset, cache.η
function tmle!(cache; verbosity=1, threshold=1e-8)
Ψ, η_spec, dataset, η, mach_cache = cache.Ψ, cache.η_spec, cache.dataset, cache.η, cache.mach_cache
# Initial fit
verbosity >= 1 && @info "Fitting the nuisance parameters..."
TMLE.fit!(η, η_spec, Ψ, dataset, verbosity=verbosity, mach_cache=mach_cache)
Expand Down Expand Up @@ -124,7 +126,7 @@ end
Runs the TMLE procedure for the new nuisance parameters specification η_spec while potentially reusing cached nuisance parameters.
"""
function tmle!(cache::TMLECache, η_spec::NuisanceSpec; verbosity=1, threshold=1e-8)
function tmle!(cache::TMLECache, η_spec::NuisanceSpec; verbosity=1, threshold=1e-8, mach_cache=false)
update!(cache, η_spec)
tmle!(cache, verbosity=verbosity, threshold=threshold)
end
Expand All @@ -138,4 +140,15 @@ while potentially reusing cached nuisance parameters.
function tmle!(cache::TMLECache, Ψ::Parameter, η_spec::NuisanceSpec; verbosity=1, threshold=1e-8)
update!(cache, Ψ, η_spec)
tmle!(cache, verbosity=verbosity, threshold=threshold)
end

"""
tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Parameter; verbosity=1, threshold=1e-8)
Runs the TMLE procedure for the new parameter Ψ and the new nuisance parameters specification η_spec
while potentially reusing cached nuisance parameters.
"""
function tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Parameter; verbosity=1, threshold=1e-8)
update!(cache, Ψ, η_spec)
tmle!(cache, verbosity=verbosity, threshold=threshold)
end

0 comments on commit c286d65

Please sign in to comment.