8484
8585include (" compiler/ssair/driver.jl" )
8686
87- mutable struct OptimizationState
87+ struct OptimizationState
8888 linfo:: MethodInstance
8989 src:: CodeInfo
90- ir:: Union{Nothing, IRCode}
9190 stmt_info:: Vector{Any}
9291 mod:: Module
9392 sptypes:: Vector{Any} # static parameters
@@ -99,8 +98,7 @@ mutable struct OptimizationState
9998 EdgeTracker (s_edges, frame. valid_worlds),
10099 WorldView (code_cache (interp), frame. world),
101100 interp)
102- return new (frame. linfo,
103- frame. src, nothing , frame. stmt_info, frame. mod,
101+ return new (frame. linfo, frame. src, frame. stmt_info, frame. mod,
104102 frame. sptypes, frame. slottypes, inlining)
105103 end
106104 function OptimizationState (linfo:: MethodInstance , src:: CodeInfo , params:: OptimizationParams , interp:: AbstractInterpreter )
@@ -127,8 +125,7 @@ mutable struct OptimizationState
127125 nothing ,
128126 WorldView (code_cache (interp), get_world_counter ()),
129127 interp)
130- return new (linfo,
131- src, nothing , stmt_info, mod,
128+ return new (linfo, src, stmt_info, mod,
132129 sptypes_from_meth_instance (linfo), slottypes, inlining)
133130 end
134131end
@@ -139,11 +136,10 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
139136 return OptimizationState (linfo, src, params, interp)
140137end
141138
142- function ir_to_codeinf! (opt:: OptimizationState )
139+ function ir_to_codeinf! (opt:: OptimizationState , ir :: IRCode )
143140 (; linfo, src) = opt
144141 optdef = linfo. def
145- replace_code_newstyle! (src, opt. ir:: IRCode , isa (optdef, Method) ? Int (optdef. nargs) : 0 )
146- opt. ir = nothing
142+ replace_code_newstyle! (src, ir, isa (optdef, Method) ? Int (optdef. nargs) : 0 )
147143 widen_all_consts! (src)
148144 src. inferred = true
149145 # finish updating the result struct
@@ -380,130 +376,155 @@ struct ConstAPI
380376end
381377
382378"""
383- finish(interp::AbstractInterpreter, opt::OptimizationState,
384- params::OptimizationParams, ir::IRCode, caller::InferenceResult) -> analyzed::Union{Nothing,ConstAPI}
385-
386- Post process information derived by Julia-level optimizations for later uses:
387- - computes "purity", i.e. side-effect-freeness
388- - computes inlining cost
389-
390- In a case when the purity is proven, `finish` can return `ConstAPI` object wrapping the constant
391- value so that the runtime system will use the constant calling convention for the method calls.
379+ finish!(interp::AbstractInterpreter,
380+ opt::OptimizationState, ir::IRCode, caller::InferenceResult)
381+
382+ Runs post-Julia-level optimization process and caches information for later uses:
383+ - computes "purity" (i.e. side-effect-freeness) of the optimized frame
384+ - computes inlining cost and caches the inlineability in `opt.src.inlineable`
385+ - stores the result of optimization in `caller.src`
386+ * by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
387+ * in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
388+ value will be kept in `caller.src` instead, so that the runtime system will use
389+ the constant calling convention
390+
391+ !!! note
392+ The lifetimes of `opt` and `ir` end by the end of this process.
393+ Still, an external `AbstractInterpreter` can override `transform_optresult_for_cache`
394+ as necessary to cache them. Note that `transform_result_for_cache` should be overloaded
395+ also in such cases, otherwise the default implementation of `transform_result_for_cache`
396+ will discard any information other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.
392397"""
393- function finish (interp:: AbstractInterpreter , opt:: OptimizationState ,
394- params:: OptimizationParams , ir:: IRCode , caller:: InferenceResult )
395- (; src, linfo) = opt
396- (; def, specTypes) = linfo
397-
398- analyzed = nothing # `ConstAPI` if this call can use constant calling convention
399- force_noinline = _any (@nospecialize (x) -> isexpr (x, :meta ) && x. args[1 ] === :noinline , ir. meta)
398+ function finish! (interp:: AbstractInterpreter ,
399+ opt:: OptimizationState , ir:: IRCode , caller:: InferenceResult )
400+ src = opt. src
400401
401- # compute inlining and other related optimizations
402402 result = caller. result
403403 @assert ! (result isa LimitedAccuracy)
404404 result = isa (result, InterConditional) ? widenconditional (result) : result
405- if (isa (result, Const) || isconstType (result))
406- proven_pure = false
407- # must be proven pure to use constant calling convention;
408- # otherwise we might skip throwing errors (issue #20704)
409- # TODO : Improve this analysis; if a function is marked @pure we should really
410- # only care about certain errors (e.g. method errors and type errors).
411- if length (ir. stmts) < 15
412- proven_pure = true
413- for i in 1 : length (ir. stmts)
414- node = ir. stmts[i]
415- stmt = node[:inst ]
416- if stmt_affects_purity (stmt, ir) && ! stmt_effect_free (stmt, node[:type ], ir)
417- proven_pure = false
418- break
419- end
420- end
421- if proven_pure
422- for fl in src. slotflags
423- if (fl & SLOT_USEDUNDEF) != 0
424- proven_pure = false
425- break
426- end
427- end
428- end
429- end
430405
431- if proven_pure
432- # use constant calling convention
433- # Do not emit `jl_fptr_const_return` if coverage is enabled
434- # so that we don't need to add coverage support
435- # to the `jl_call_method_internal` fast path
436- # Still set pure flag to make sure `inference` tests pass
437- # and to possibly enable more optimization in the future
438- src . pure = true
406+ newresult = nothing # ConstAPI if this call can use constant calling convention
407+ if isa (result, Const) || isconstType (result)
408+ # computes "purity" (i.e. side-effect-freeness)
409+ if compute_purity (ir, src)
410+ src . inlineable = src . pure = true
411+
412+ # must be proven pure to use constant calling convention;
413+ # otherwise we might skip throwing errors (issue #20704)
439414 if isa (result, Const)
440415 val = result. val
441416 if is_inlineable_constant (val)
442- analyzed = ConstAPI (val)
417+ newresult = ConstAPI (val)
443418 end
444419 else
445420 @assert isconstType (result)
446- analyzed = ConstAPI (result. parameters[1 ])
421+ newresult = ConstAPI (result. parameters[1 ])
447422 end
448- force_noinline || (src. inlineable = true )
449423 end
450424 end
451425
452- opt. ir = ir
453-
454426 # determine and cache inlineability
455- union_penalties = false
456- if ! force_noinline
457- sig = unwrap_unionall (specTypes)
458- if isa (sig, DataType) && sig. name === Tuple. name
459- for P in sig. parameters
460- P = unwrap_unionall (P)
461- if isa (P, Union)
462- union_penalties = true
463- break
464- end
427+ src. inlineable = compute_inlineability (ir, opt, result, src. inlineable)
428+
429+ caller. valid_worlds = (opt. inlining. et:: EdgeTracker ). valid_worlds[]
430+
431+ caller. src = transform_optresult_for_cache (interp, opt, ir, newresult)
432+
433+ return nothing
434+ end
435+
436+ function compute_purity (ir:: IRCode , src:: CodeInfo )
437+ # TODO : Improve this analysis; if a function is marked @pure we should really
438+ # only care about certain errors (e.g. method errors and type errors).
439+ if length (ir. stmts) < 15
440+ for i in 1 : length (ir. stmts)
441+ node = ir. stmts[i]
442+ stmt = node[:inst ]
443+ if stmt_affects_purity (stmt, ir) && ! stmt_effect_free (stmt, node[:type ], ir)
444+ return false
465445 end
466- else
467- force_noinline = true
468446 end
469- if ! src. inlineable && result === Bottom
470- force_noinline = true
447+ for flag in src. slotflags
448+ if (flag & SLOT_USEDUNDEF) != 0
449+ return false
450+ end
471451 end
452+ return true
472453 end
473- if force_noinline
474- src. inlineable = false
475- elseif isa (def, Method)
476- if src. inlineable && isdispatchtuple (specTypes)
477- # obey @inline declaration if a dispatch barrier would not help
478- else
479- # compute the cost (size) of inlining this code
480- cost_threshold = default = params. inline_cost_threshold
481- if result ⊑ Tuple && ! isconcretetype (widenconst (result))
482- cost_threshold += params. inline_tupleret_bonus
483- end
484- # if the method is declared as `@inline`, increase the cost threshold 20x
485- if src. inlineable
486- cost_threshold += 19 * default
487- end
488- # a few functions get special treatment
489- if def. module === _topmod (def. module)
490- name = def. name
491- if name === :iterate || name === :unsafe_convert || name === :cconvert
492- cost_threshold += 4 * default
493- end
454+ return false
455+ end
456+
457+ function compute_inlineability (ir:: IRCode , opt:: OptimizationState , @nospecialize (result),
458+ declared_inlineability:: Bool )
459+ (; def, specTypes) = opt. linfo
460+ force_noinline = _any (@nospecialize (x) -> isexpr (x, :meta ) && x. args[1 ] === :noinline , ir. meta)
461+ force_noinline && return false
462+ union_penalties = false
463+ sig = unwrap_unionall (specTypes)
464+ if isa (sig, DataType) && sig. name === Tuple. name
465+ for P in sig. parameters
466+ P = unwrap_unionall (P)
467+ if isa (P, Union)
468+ union_penalties = true
469+ break
494470 end
495- src. inlineable = inline_worthy (ir, params, union_penalties, cost_threshold)
496471 end
472+ else
473+ return false
474+ end
475+ if ! declared_inlineability && result === Bottom
476+ return false
477+ end
478+ isa (def, Method) || return declared_inlineability
479+ if declared_inlineability && isdispatchtuple (specTypes)
480+ # obey @inline declaration if a dispatch barrier would not help
481+ return true
482+ end
483+ # compute the cost (size) of inlining this code
484+ params = opt. inlining. params
485+ cost_threshold = default = params. inline_cost_threshold
486+ if result ⊑ Tuple && ! isconcretetype (widenconst (result))
487+ cost_threshold += params. inline_tupleret_bonus
497488 end
489+ # if the method is declared as `@inline`, increase the cost threshold 20x
490+ if declared_inlineability
491+ cost_threshold += 19 * default
492+ end
493+ # a few functions get special treatment
494+ if def. module === _topmod (def. module)
495+ name = def. name
496+ if name === :iterate || name === :unsafe_convert || name === :cconvert
497+ cost_threshold += 4 * default
498+ end
499+ end
500+ return inline_worthy (ir, params, union_penalties, cost_threshold)
501+ end
498502
499- return analyzed
503+ function transform_optresult_for_cache (:: AbstractInterpreter ,
504+ opt:: OptimizationState , ir:: IRCode , @nospecialize (newresult))
505+ if isa (newresult, ConstAPI)
506+ # use constant calling convention
507+ # Do not emit `jl_fptr_const_return` if coverage is enabled
508+ # so that we don't need to add coverage support
509+ # to the `jl_call_method_internal` fast path
510+ # Still set pure flag to make sure `inference` tests pass
511+ # and to possibly enable more optimization in the future
512+
513+ # XXX : The work in ir_to_codeinf! is essentially wasted. The only reason
514+ # we're doing it is so that code_llvm can return the code
515+ # for the `return ...::Const` (which never runs anyway). We should do this
516+ # as a post processing step instead.
517+ ir_to_codeinf! (opt, ir)
518+ return newresult
519+ end
520+ return ir_to_codeinf! (opt, ir)
500521end
501522
502523# run the optimization work
503- function optimize (interp:: AbstractInterpreter , opt :: OptimizationState ,
504- params :: OptimizationParams , caller:: InferenceResult )
524+ function optimize! (interp:: AbstractInterpreter ,
525+ opt :: OptimizationState , caller:: InferenceResult )
505526 @timeit " optimizer" ir = run_passes (opt. src, opt, caller)
506- return finish (interp, opt, params , ir, caller)
527+ @timeit " finish! " finish! (interp, opt, ir, caller)
507528end
508529
509530using . EscapeAnalysis
0 commit comments