Skip to content

Commit cdbb628

Browse files
committed
check for failing case
1 parent db68c59 commit cdbb628

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

projects/eudsl-python-extras/mlir/extras/ast/util.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def ast_call(name, args=None, keywords=None):
3838
def get_module_cst(f):
3939
f_src = dedent(inspect.getsource(f))
4040
tree = ast.parse(f_src)
41-
assert isinstance(
42-
tree.body[0], ast.FunctionDef
43-
), f"unexpected ast node {tree.body[0]}"
41+
assert isinstance(tree.body[0], ast.FunctionDef), (
42+
f"unexpected ast node {tree.body[0]}"
43+
)
4444
return tree
4545

4646

@@ -92,7 +92,7 @@ def replace_closure(code, new_closure: Dict):
9292
LOAD_DEREF = opmap["LOAD_DEREF"]
9393

9494
# get the orig localplus that will be loaded from by the orig bytecode LOAD_DEREF arg_i
95-
localsplus, localsplus_name_to_idx = get_localsplus_name_to_idx(code)
95+
localsplus, _localsplus_name_to_idx = get_localsplus_name_to_idx(code)
9696

9797
# closure vars go into co_freevars
9898
new_code = code.replace(co_freevars=tuple(new_closure.keys()))
@@ -167,9 +167,9 @@ def copy_func(f, new_closure: Dict = None):
167167

168168
def append_hidden_node(node_body, new_node):
169169
last_statement = node_body[-1]
170-
assert (
171-
last_statement.end_lineno is not None
172-
), f"last_statement {ast.unparse(last_statement)} must have end_lineno"
170+
assert last_statement.end_lineno is not None, (
171+
f"last_statement {ast.unparse(last_statement)} must have end_lineno"
172+
)
173173
new_node = ast.fix_missing_locations(
174174
set_lineno(new_node, last_statement.end_lineno)
175175
)

projects/eudsl-python-extras/mlir/extras/dialects/func.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ def __call__(self, *call_args):
298298
return call(self.emit(*call_args), call_args)
299299

300300
def __getitem__(self, item):
301+
if not isinstance(item, tuple):
302+
item = (item,)
301303
if self.generics is None:
302304
raise RuntimeError(
303305
"using a generic call requires the func be generic (i.e., have type_params)"
@@ -322,22 +324,26 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
322324
continue
323325
if k not in already_reified_type_params:
324326
raise RuntimeError(
325-
f"typevar {k} not reified prior to evaluating dependent typevar {t}"
327+
f"typevar {k} not reified prior to evaluating dependent typevar {tvar}"
326328
)
327329
cvrs[k] = already_reified_type_params[k]
328330
unevaled_type_data = copy_func(unevaled_type_data, cvrs)
329331
return unevaled_type_data()
330332

331-
generics = copy.deepcopy(self.generics)
332-
for i, t in enumerate(generics):
333+
generics = copy_object(self.generics)
334+
for i, tvar in enumerate(generics):
335+
if not isinstance(tvar, TypeVar):
336+
raise RuntimeError(
337+
f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, you should give it a unique name."
338+
)
333339
type_var_default = None
334340
if sys.version_info >= (3, 12):
335-
type_var = PyTypeVarObject.from_object(t)
341+
type_var = PyTypeVarObject.from_object(tvar)
336342
type_var_bound = type_var.bound
337-
if sys.version_info >= (3, 13) and t.has_default():
343+
if sys.version_info >= (3, 13) and tvar.has_default():
338344
type_var_default = type_var.default_value
339345
else:
340-
type_var_bound = t.__bound__
346+
type_var_bound = tvar.__bound__
341347

342348
if bool(type_var_bound):
343349
# before 3.12 typevar was just a python class
@@ -346,7 +352,7 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
346352
type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
347353
elif not bool(type_var_default):
348354
if i >= len(item):
349-
raise RuntimeError(f"generic {t} must have concrete val")
355+
raise RuntimeError(f"generic {tvar=} must have concrete val")
350356
if isinstance(item[i], Type):
351357
type_var_bound = "type"
352358
else:
@@ -358,29 +364,31 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
358364
val = type_var_default
359365
else:
360366
if i >= len(item):
361-
raise RuntimeError(f"generic {t} must have concrete val")
367+
raise RuntimeError(f"generic {tvar=} must have concrete val")
362368
val = item[i]
363369

364-
r = ReifiedTypeParams(t.__name__, val, type_var_bound)
370+
r = ReifiedTypeParams(tvar.__name__, val, type_var_bound)
365371

366372
reified_type_params.append(r)
367373
already_reified_type_params[r.name] = r.val
368374

369-
if t.__name__ in body_builder.__globals__:
370-
body_builder.__globals__[t.__name__] = r.val
375+
# replace the tvar in body_builder's global context with the reified val
376+
if tvar.__name__ in body_builder.__globals__:
377+
body_builder.__globals__[tvar.__name__] = r.val
378+
# replace the tvar in body_builder's closure with the reified val
371379
if r.name in body_builder.__code__.co_freevars:
372380
free_i = body_builder.__code__.co_freevars.index(r.name)
373381
assert (
374-
body_builder.__closure__[free_i].cell_contents == t
382+
body_builder.__closure__[free_i].cell_contents == tvar
375383
), "typevars don't match"
376384
body_builder.__closure__[free_i].cell_contents = r.val
377385

378386
name_mangled_generics = []
379387
for r in reified_type_params:
380-
t, v = r.type, r.val
388+
tvar, v = r.type, r.val
381389
if callable(v):
382390
v = v.__name__
383-
name_mangled_generics.append(f"{t}_{v}")
391+
name_mangled_generics.append(f"{tvar}_{v}")
384392

385393
return FuncBase(
386394
body_builder,

projects/eudsl-python-extras/tests/dialect/test_func.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ def mat_product_kernel2():
232232
one = arith.constant(1, T.f32())
233233
two = _op(one, one)
234234

235-
mat_product_kernel1[arith.maximumf,].emit()
236-
mat_product_kernel2[arith.minimumf,].emit()
237-
mat_product_kernel2[arith.maximumf,].emit()
235+
mat_product_kernel1[arith.maximumf].emit()
236+
mat_product_kernel2[arith.minimumf].emit()
237+
mat_product_kernel2[arith.maximumf].emit()
238238

239239
# CHECK: func.func @mat_product_kernel1_function_maximumf() {
240240
# CHECK: %cst = arith.constant 1.000000e+00 : f32
@@ -265,7 +265,7 @@ def test_global_closures(ctx: MLIRContext):
265265
def _generic_pool2d_scf(a: T.f32(), b: T.f32()):
266266
_op(a, b)
267267

268-
_maxpool2d_scf = _generic_pool2d_scf[arith.maximumf,]
268+
_maxpool2d_scf = _generic_pool2d_scf[arith.maximumf]
269269

270270
# _op = TypeVar("_op")
271271

@@ -276,7 +276,11 @@ def _generic_pool3d_scf(
276276
):
277277
_op(a, b)
278278

279-
_maxpool3d_scf = _generic_pool3d_scf[arith.maximumf,]
279+
with pytest.raises(
280+
RuntimeError,
281+
match="0th generic has probably already been reified as <function maximumf at .*?>; if you're using a global tvar for the generic, you should give it a unique name.",
282+
):
283+
_maxpool3d_scf = _generic_pool3d_scf[arith.maximumf]
280284

281285

282286
def test_generics_with_canonicalizations(ctx: MLIRContext):

0 commit comments

Comments
 (0)