Skip to content

Commit

Permalink
Merge pull request #34 from TuringLang/dw/chunksize
Browse files Browse the repository at this point in the history
More reasonable default values of chunk size
  • Loading branch information
yebai authored Mar 15, 2022
2 parents 382c923 + c7cad88 commit db93421
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.1.3"
version = "0.1.4"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
9 changes: 6 additions & 3 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,13 @@ function grad!(
- vo(alg, q(θ_), model, args...)
end

chunk_size = getchunksize(typeof(alg))
# Set chunk size and do ForwardMode.
chunk = ForwardDiff.Chunk(min(length(θ), chunk_size))
config = ForwardDiff.GradientConfig(f, θ, chunk)
chunk_size = getchunksize(typeof(alg))
config = if chunk_size == 0
ForwardDiff.GradientConfig(f, θ)
else
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
end
ForwardDiff.gradient!(out, f, θ, config)
end

Expand Down
9 changes: 3 additions & 6 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ function setadbackend(::Val{:forward_diff})
setadbackend(Val(:forwarddiff))
end
function setadbackend(::Val{:forwarddiff})
CHUNKSIZE[] == 0 && setchunksize(40)
ADBACKEND[] = :forwarddiff
end

Expand All @@ -26,13 +25,11 @@ function setadsafe(switch::Bool)
ADSAFE[] = switch
end

const CHUNKSIZE = Ref(40) # default chunksize used by AD
const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically

function setchunksize(chunk_size::Int)
if ~(CHUNKSIZE[] == chunk_size)
@info("[AdvancedVI]: AD chunk size is set as $chunk_size")
CHUNKSIZE[] = chunk_size
end
@info("[AdvancedVI]: AD chunk size is set as $chunk_size")
CHUNKSIZE[] = chunk_size
end

abstract type ADBackend end
Expand Down

2 comments on commit db93421

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/56670

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.4 -m "<description of version>" db93421c2b7de180210dd42919ed3ad303b0c9d0
git push origin v0.1.4

Please sign in to comment.