Skip to content

Commit

Permalink
Type stability fixes surrounding passing of threadlocal variable
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jan 20, 2022
1 parent 6603529 commit 5e6ae4c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
42 changes: 22 additions & 20 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,25 @@ function add_var!(q, argtup, gcpres, ::Type{T}, argtupname, gcpresname, k) where
end
end
@generated function _batch_no_reserve(
f!::F, threadmask_tuple::NTuple{N}, nthread_tuple, torelease_tuple, Nr, Nd, ulen, args::Vararg{Any,K}; threadlocal::Bool=false
) where {F,K,N}
f!::F, threadmask_tuple::NTuple{N}, nthread_tuple, torelease_tuple, Nr, Nd, ulen, args::Vararg{Any,K}; threadlocal::Val{thread_local}=Val(false)
) where {F,K,N,thread_local}
q = quote
$(Expr(:meta,:inline))
# threads = UnsignedIteratorEarlyStop(threadmask, nthread)
# threads_tuple = map(UnsignedIteratorEarlyStop, threadmask_tuple, nthread_tuple)
# nthread_total = sum(nthread_tuple)
Ndp = Nd + one(Nd)
end
launch_quote = if thread_local
:(launch_batched_thread!(cfunc, tid, argtup, start, stop, i%UInt))
else
:(launch_batched_thread!(cfunc, tid, argtup, start, stop))
end
rem_quote = if thread_local
:(f!(arguments, (start+one(UInt)) % Int, ulen % Int, (sum(nthread_tuple)+1)%Int))
else
:(f!(arguments, (start+one(UInt)) % Int, ulen % Int))
end
block = quote
start = zero(UInt)
tid = 0x00000000
Expand All @@ -92,20 +102,12 @@ end
tz += 0x00000001
tid += tz
tm >>>= tz
if threadlocal
launch_batched_thread!(cfunc, tid, argtup, start, stop, i%UInt)
else
launch_batched_thread!(cfunc, tid, argtup, start, stop)
end
$launch_quote
start = stop
end
Nr -= nthread
end
if threadlocal
f!(arguments, (start+one(UInt)) % Int, ulen % Int, (sum(nthread_tuple)+1)%Int)
else
f!(arguments, (start+one(UInt)) % Int, ulen % Int)
end
$rem_quote
for (threadmask, nthread, torelease) ∈ zip(threadmask_tuple, nthread_tuple, torelease_tuple)
tm = mask(UnsignedIteratorEarlyStop(threadmask, nthread))
tid = 0x00000000
Expand All @@ -127,7 +129,7 @@ end
for k ∈ 1:K
add_var!(q, argt, gcpr, args[k], :args, :gcp, k)
end
push!(q.args, :(arguments = $argt), :(argtup = Reference(arguments)), :(cfunc = batch_closure(f!, argtup, Val{false}(), Val{threadlocal}())), gcpr)
push!(q.args, :(arguments = $argt), :(argtup = Reference(arguments)), :(cfunc = batch_closure(f!, argtup, Val{false}(), Val{$thread_local}())), gcpr)
push!(q.args, nothing)
q
end
Expand Down Expand Up @@ -227,15 +229,15 @@ end


@inline function batch(
f!::F, (len, nbatches)::Tuple{Vararg{Integer,2}}, args::Vararg{Any,K}; threadlocal::Bool=false
) where {F,K}
f!::F, (len, nbatches)::Tuple{Vararg{Integer,2}}, args::Vararg{Any,K}; threadlocal::Val{thread_local}=Val{false}()
) where {F,K,thread_local}
# threads, torelease = request_threads(Base.Threads.threadid(), nbatches - one(nbatches))
threads, torelease = request_threads(nbatches - one(nbatches))
nthreads = map(length,threads)
nthread = sum(nthreads)
ulen = len % UInt
if nthread % Int32 ≤ zero(Int32)
if threadlocal
if thread_local
f!(args, one(Int), ulen % Int, 1)
else
f!(args, one(Int), ulen % Int)
Expand All @@ -246,12 +248,12 @@ end
Nd = Base.udiv_int(ulen, nbatch % UInt) # reasonable for `ulen` to be ≥ 2^32
Nr = ulen - Nd * nbatch

_batch_no_reserve(f!, map(mask,threads), nthreads, torelease, Nr, Nd, ulen, args...; threadlocal=threadlocal)
_batch_no_reserve(f!, map(mask,threads), nthreads, torelease, Nr, Nd, ulen, args...; threadlocal)
end
function batch(
f!::F, (len, nbatches, reserve_per_worker)::Tuple{Vararg{Integer,3}}, args::Vararg{Any,K}; threadlocal::Bool=false
) where {F,K}
batch(f!, (len, nbatches), args...; threadlocal=false)
f!::F, (len, nbatches, reserve_per_worker)::Tuple{Vararg{Integer,3}}, args::Vararg{Any,K}; threadlocal::Val{thread_local}=Val(false)
) where {F,K,thread_local}
batch(f!, (len, nbatches), args...; threadlocal)
# ulen = len % UInt
# if nbatches > 1
# requested_threads = reserve_per_worker*nbatches
Expand Down
2 changes: 1 addition & 1 deletion src/closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ function enclose(exorig::Expr, reserve_per, minbatchsize, per::Symbol, threadloc
push!(batchcall.args, esc(a))
end
if threadlocal !== Symbol("")
push!(batchcall.args, Expr(:kw, :threadlocal, true))
push!(batchcall.args, Expr(:kw, :threadlocal, Val(true)))
end
push!(q.args, batchcall)
quote
Expand Down

2 comments on commit 5e6ae4c

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/52856

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.6.3 -m "<description of version>" 5e6ae4c2ae009b507bbcf90ff2c2b9b7d5e94559
git push origin v0.6.3

Please sign in to comment.