Skip to content

Commit 1b14a4a

Browse files
committed
Use maxthreadid() in TSVI
1 parent e570615 commit 1b14a4a

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

src/threadsafe.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo)
1313
# fields. This is not good practice --- see
1414
# https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full
1515
# explanation --- but it has worked okay so far.
16-
# The use of nthreads()*2 here ensures that threadid() doesn't exceed
17-
# the length of the logps array. Ideally, we would use maxthreadid(),
18-
# but Mooncake can't differentiate through that. Empirically, nthreads()*2
19-
# seems to provide an upper bound to maxthreadid(), so we use that here.
20-
# See https://github.com/TuringLang/DynamicPPL.jl/pull/936
21-
accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)]
16+
accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.maxthreadid()]
2217
return ThreadSafeVarInfo(vi, accs_by_thread)
2318
end
2419
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi

0 commit comments

Comments
 (0)