Skip to content

Commit b2334af

Browse files
committed
AbstractInterpreter: define new infresult_iterator interface
to make it easier to customize the behaviors of post processing of `_typeinf`. Especially, this change is motivated by a need for JET, whose post processing requires references of `InferenceState`s. Separated from #43994.
1 parent a0093d2 commit b2334af

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

base/compiler/typeinfer.jl

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,17 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
227227
# with no active ip's, frame is done
228228
frames = frame.callers_in_cycle
229229
isempty(frames) && push!(frames, frame)
230+
# collect results for the new expanded frame
231+
finish_infstates!(interp, frames)
232+
results = infresult_iterator(interp, frames)
233+
# run optimization on results in the resolved cycle
234+
optimize_results!(interp, results)
235+
# now cache the optimized results
236+
cache_results!(interp, results)
237+
return true
238+
end
239+
240+
function finish_infstates!(interp::AbstractInterpreter, frames::Vector{InferenceState})
230241
valid_worlds = WorldRange()
231242
for caller in frames
232243
@assert !(caller.dont_work_on_me)
@@ -240,14 +251,27 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
240251
# finalize and record the linfo result
241252
caller.inferred = true
242253
end
243-
# collect results for the new expanded frame
244-
results = Tuple{InferenceResult, Vector{Any}, Bool}[
245-
( frames[i].result,
246-
frames[i].stmt_edges[1]::Vector{Any},
247-
frames[i].cached )
254+
end
255+
256+
struct InfResultInfo
257+
caller::InferenceResult
258+
edges::Vector{Any}
259+
cached::Bool
260+
end
261+
262+
# returns iterator on which `optimize_results!` and `cache_results!` work on
263+
function infresult_iterator(_::AbstractInterpreter, frames::Vector{InferenceState})
264+
results = InfResultInfo[ InfResultInfo(
265+
frames[i].result,
266+
frames[i].stmt_edges[1]::Vector{Any},
267+
frames[i].cached )
248268
for i in 1:length(frames) ]
249-
empty!(frames)
250-
for (caller, _, _) in results
269+
empty!(frames) # discard `InferenceState` now
270+
return results
271+
end
272+
273+
function optimize_results!(interp::AbstractInterpreter, results::Vector{InfResultInfo})
274+
for (; caller) in results
251275
opt = caller.src
252276
if opt isa OptimizationState # implies `may_optimize(interp) === true`
253277
analyzed = optimize(interp, opt, OptimizationParams(interp), caller)
@@ -262,7 +286,10 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
262286
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]
263287
end
264288
end
265-
for (caller, edges, cached) in results
289+
end
290+
291+
function cache_results!(interp::AbstractInterpreter, results::Vector{InfResultInfo})
292+
for (; caller, edges, cached) in results
266293
valid_worlds = caller.valid_worlds
267294
if last(valid_worlds) >= get_world_counter()
268295
# if we aren't cached, we don't need this edge
@@ -274,7 +301,6 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
274301
end
275302
finish!(interp, caller)
276303
end
277-
return true
278304
end
279305

280306
function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
@@ -348,7 +374,8 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
348374
end
349375

350376
function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodInstance,
351-
valid_worlds::WorldRange, @nospecialize(inferred_result))
377+
valid_worlds::WorldRange, result::InferenceResult)
378+
inferred_result = result.src
352379
# If we decided not to optimize, drop the OptimizationState now.
353380
# External interpreters can override as necessary to cache additional information
354381
if inferred_result isa OptimizationState
@@ -383,7 +410,7 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
383410

384411
# TODO: also don't store inferred code if we've previously decided to interpret this function
385412
if !already_inferred
386-
inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result.src)
413+
inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result)
387414
relocatability = isa(inferred_result, Vector{UInt8}) ? inferred_result[end] : UInt8(0)
388415
code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds, relocatability)
389416
end

0 commit comments

Comments
 (0)