Skip to content

Commit 2d12daa

Browse files
committed
incorporate method match world age into generated code
This implements feedback from the GitHub issue comment at #66 (comment). Specifically, we now use the world age from method matches by directly utilizing `Base.Compiler.findall`. However, the updated tests pass even without this change, which suggests there may be issues with the test coverage.
1 parent f590a4d commit 2d12daa

File tree

2 files changed

+48
-28
lines changed

2 files changed

+48
-28
lines changed

src/CassetteOverlay.jl

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,16 @@ function generate_overlay_src(
5959
tt = Base.to_tuple_type(fargtypes)
6060
mt_worlds = methodtable(world, passtype)
6161
if mt_worlds isa Pair
62-
method_table, worlds = mt_worlds
62+
method_table, mtworlds = mt_worlds
6363
else
6464
method_table = mt_worlds
65-
worlds = nothing
65+
mtworlds = nothing
6666
end
6767
match = Base._which(tt; method_table, raise = false, world)
68-
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
68+
results = Base.Compiler.findall(tt, method_table; limit=1)
69+
length(results) == 1 || return nothing # method match failed – the fallback implementation will raise a proper MethodError
70+
match = results[1]
71+
match_worlds = results.valid_worlds
6972
mi = Core.Compiler.specialize_method(match)
7073
src = Core.Compiler.retrieve_code_info(mi, world)
7174
src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it
@@ -78,12 +81,24 @@ function generate_overlay_src(
7881
push!(invalid_code, (world, source, passtype, fargtypes, src, selfname, fargsname))
7982
# TODO `return nothing` when updating the minimum compat to 1.12
8083
end
81-
if worlds !== nothing
82-
src.min_world, src.max_world = max(src.min_world, first(worlds)), min(src.max_world, last(worlds))
84+
if mtworlds !== nothing
85+
src.min_world, src.max_world = max(src.min_world, first(mtworlds)), min(src.max_world, last(mtworlds))
8386
end
87+
src.min_world, src.max_world = max(src.min_world, first(match_worlds)), min(src.max_world, last(match_worlds))
8488
return src
8589
end
8690

91+
function get_mt_worlds(m::Module, var::Symbol, world::UInt)
92+
@static if VERSION v"1.12-"
93+
@assert isconst_at_world(m, var, world)
94+
mt, worlds = getglobal_at_world(m, var, world)
95+
return Base.Compiler.OverlayMethodTable(world, mt::MethodTable) => worlds
96+
else
97+
@assert @invokelatest isconst(M, S)
98+
return getglobal(M, S)::MethodTable
99+
end
100+
end
101+
87102
macro overlaypass(args...)
88103
if length(args) == 1
89104
PassName = nothing
@@ -92,6 +107,10 @@ macro overlaypass(args...)
92107
PassName, method_table = args
93108
end
94109

110+
if !(method_table === nothing || method_table isa Symbol || Meta.isexpr(method_table, :.))
111+
error("Unexpected @overlaypass call")
112+
end
113+
95114
if PassName === nothing
96115
PassName = esc(gensym(string(method_table)))
97116
decl_pass = :(struct $PassName <: $OverlayPass end)
@@ -104,9 +123,22 @@ macro overlaypass(args...)
104123

105124
nonoverlaytype = typeof(CassetteOverlay.nonoverlay)
106125

107-
if method_table !== :nothing
108-
mthd_tbl = :($CassetteOverlay.methodtable(world::UInt, ::Type{$PassName}) =
109-
Base.Compiler.OverlayMethodTable(world, $(esc(method_table))))
126+
if method_table isa Symbol
127+
mthd_tbl = :(
128+
function $CassetteOverlay.methodtable(world::UInt, ::Type{$PassName})
129+
return $CassetteOverlay.get_mt_worlds($__module__, $(QuoteNode(method_table)), world)
130+
end
131+
)
132+
elseif Meta.isexpr(method_table, :.)
133+
M, S = method_table.args
134+
if !(M isa Symbol && S isa QuoteNode && S.value isa Symbol)
135+
error("Unexpected @overlaypass call")
136+
end
137+
mthd_tbl = :(
138+
function $CassetteOverlay.methodtable(world::UInt, ::Type{$PassName})
139+
return $CassetteOverlay.get_mt_worlds($(esc(M)), $S, world)
140+
end
141+
)
110142
else
111143
mthd_tbl = nothing
112144
end
@@ -189,22 +221,12 @@ end
189221

190222
abstract type AbstractBindingOverlay{M, S} <: OverlayPass; end
191223
function methodtable(world::UInt, ::Type{<:AbstractBindingOverlay{M, S}}) where {M, S}
192-
if M === nothing
193-
return nothing
194-
end
195-
@static if VERSION v"1.12-"
196-
@assert isconst_at_world(M, S, world)
197-
mt, worlds = getglobal_at_world(M, S, world)
198-
return Base.Compiler.OverlayMethodTable(world, mt::MethodTable) => worlds
199-
else
200-
@assert @invokelatest isconst(M, S)
201-
return getglobal(M, S)::MethodTable
202-
end
224+
(M isa Module && S isa Symbol) || error("Unexpected AbstractBindingOverlay type")
225+
return get_mt_worlds(M, S, world)
203226
end
204227
@overlaypass AbstractBindingOverlay nothing
205228

206-
struct Overlay{M, S} <: AbstractBindingOverlay{M, S}
207-
end
229+
struct Overlay{M, S} <: AbstractBindingOverlay{M, S} end
208230
function Overlay(mt::MethodTable)
209231
@assert @invokelatest isconst(mt.module, mt.name)
210232
@assert mt === @invokelatest getglobal(mt.module, mt.name)

test/simple.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@ myidentity(@nospecialize x) = x
1010
kwifelse(x, y; cond=true) = ifelse(cond, x, y)
1111

1212
# run overlayed methods
13-
@overlay SimpleTable myidentity(@nospecialize x) = 42
14-
@test pass(myidentity, nothing) == 42
15-
@test pass() do
16-
myidentity(nothing)
17-
end == 42
13+
@overlay SimpleTable myidentity(@nospecialize x) = (@noinline; (println(devnull, "prevent inlining")); 42)
14+
call_myidentity() = @noinline myidentity(nothing)
15+
@test pass(call_myidentity) == 42
1816

1917
# kwargs
2018
@overlay SimpleTable kwifelse(x, y; cond=true) = ifelse(cond, y, x)
@@ -28,8 +26,8 @@ let (x, y) = (0, 1)
2826
end
2927

3028
# method invalidation
31-
@overlay SimpleTable myidentity(@nospecialize x) = 0
32-
@test pass(myidentity, nothing) == 0
29+
@overlay SimpleTable myidentity(@nospecialize x) = (@noinline; (println(devnull, "prevent inlining")); 0)
30+
@test pass(call_myidentity) == 0
3331

3432
# nonoverlay
3533
@overlay SimpleTable myidentity(@nospecialize x) = nonoverlay(myidentity, x)

0 commit comments

Comments
 (0)