@@ -59,13 +59,16 @@ function generate_overlay_src(
5959 tt = Base. to_tuple_type (fargtypes)
6060 mt_worlds = methodtable (world, passtype)
6161 if mt_worlds isa Pair
62- method_table, worlds = mt_worlds
62+ method_table, mtworlds = mt_worlds
6363 else
6464 method_table = mt_worlds
65- worlds = nothing
65+ mtworlds = nothing
6666 end
6767 match = Base. _which (tt; method_table, raise = false , world)
68- match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
68+ results = Base. Compiler. findall (tt, method_table; limit= 1 )
69+ length (results) == 1 || return nothing # method match failed – the fallback implementation will raise a proper MethodError
70+ match = results[1 ]
71+ match_worlds = results. valid_worlds
6972 mi = Core. Compiler. specialize_method (match)
7073 src = Core. Compiler. retrieve_code_info (mi, world)
7174 src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it
@@ -78,12 +81,24 @@ function generate_overlay_src(
7881 push! (invalid_code, (world, source, passtype, fargtypes, src, selfname, fargsname))
7982 # TODO `return nothing` when updating the minimum compat to 1.12
8083 end
81- if worlds != = nothing
82- src. min_world, src. max_world = max (src. min_world, first (worlds )), min (src. max_world, last (worlds ))
84+ if mtworlds != = nothing
85+ src. min_world, src. max_world = max (src. min_world, first (mtworlds )), min (src. max_world, last (mtworlds ))
8386 end
87+ src. min_world, src. max_world = max (src. min_world, first (match_worlds)), min (src. max_world, last (match_worlds))
8488 return src
8589end
8690
91+ function get_mt_worlds (m:: Module , var:: Symbol , world:: UInt )
92+ @static if VERSION ≥ v " 1.12-"
93+ @assert isconst_at_world (m, var, world)
94+ mt, worlds = getglobal_at_world (m, var, world)
95+ return Base. Compiler. OverlayMethodTable (world, mt:: MethodTable ) => worlds
96+ else
97+ @assert @invokelatest isconst (M, S)
98+ return getglobal (M, S):: MethodTable
99+ end
100+ end
101+
87102macro overlaypass (args... )
88103 if length (args) == 1
89104 PassName = nothing
@@ -92,6 +107,10 @@ macro overlaypass(args...)
92107 PassName, method_table = args
93108 end
94109
110+ if ! (method_table === nothing || method_table isa Symbol || Meta. isexpr (method_table, :.))
111+ error (" Unexpected @overlaypass call" )
112+ end
113+
95114 if PassName === nothing
96115 PassName = esc (gensym (string (method_table)))
97116 decl_pass = :(struct $ PassName <: $OverlayPass end )
@@ -104,9 +123,22 @@ macro overlaypass(args...)
104123
105124 nonoverlaytype = typeof (CassetteOverlay. nonoverlay)
106125
107- if method_table != = :nothing
108- mthd_tbl = :($ CassetteOverlay. methodtable (world:: UInt , :: Type{$PassName} ) =
109- Base. Compiler. OverlayMethodTable (world, $ (esc (method_table))))
126+ if method_table isa Symbol
127+ mthd_tbl = :(
128+ function $CassetteOverlay. methodtable (world:: UInt , :: Type{$PassName} )
129+ return $ CassetteOverlay. get_mt_worlds ($ __module__, $ (QuoteNode (method_table)), world)
130+ end
131+ )
132+ elseif Meta. isexpr (method_table, :.)
133+ M, S = method_table. args
134+ if ! (M isa Symbol && S isa QuoteNode && S. value isa Symbol)
135+ error (" Unexpected @overlaypass call" )
136+ end
137+ mthd_tbl = :(
138+ function $CassetteOverlay. methodtable (world:: UInt , :: Type{$PassName} )
139+ return $ CassetteOverlay. get_mt_worlds ($ (esc (M)), $ S, world)
140+ end
141+ )
110142 else
111143 mthd_tbl = nothing
112144 end
@@ -189,22 +221,12 @@ end
189221
190222abstract type AbstractBindingOverlay{M, S} <: OverlayPass ; end
191223function methodtable (world:: UInt , :: Type{<:AbstractBindingOverlay{M, S}} ) where {M, S}
192- if M === nothing
193- return nothing
194- end
195- @static if VERSION ≥ v " 1.12-"
196- @assert isconst_at_world (M, S, world)
197- mt, worlds = getglobal_at_world (M, S, world)
198- return Base. Compiler. OverlayMethodTable (world, mt:: MethodTable ) => worlds
199- else
200- @assert @invokelatest isconst (M, S)
201- return getglobal (M, S):: MethodTable
202- end
224+ (M isa Module && S isa Symbol) || error (" Unexpected AbstractBindingOverlay type" )
225+ return get_mt_worlds (M, S, world)
203226end
204227@overlaypass AbstractBindingOverlay nothing
205228
206- struct Overlay{M, S} <: AbstractBindingOverlay{M, S}
207- end
229+ struct Overlay{M, S} <: AbstractBindingOverlay{M, S} end
208230function Overlay (mt:: MethodTable )
209231 @assert @invokelatest isconst (mt. module, mt. name)
210232 @assert mt === @invokelatest getglobal (mt. module, mt. name)
0 commit comments