Skip to content

Commit 8573df5

Browse files
committed
[eudsl-python-extras] fix more func stuff
1 parent e394189 commit 8573df5

File tree

8 files changed

+687
-733
lines changed

8 files changed

+687
-733
lines changed

.github/workflows/build_test_release_eudsl_python_extras.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ jobs:
6464

6565
- name: "Build eudsl-python-extras sdist"
6666
run: |
67+
6768
SHA_SHORT="$(git rev-parse --short HEAD)"
6869
WHEEL_VERSION="$(date +'%Y%m%d.%H%M')+$SHA_SHORT"
6970
pushd projects/eudsl-python-extras
@@ -175,7 +176,16 @@ jobs:
175176
run: python -m pip install dist/eudsl_python_extras*.tar.gz
176177

177178
- name: "Test eudsl-python-extras"
178-
run: python -m pytest projects/eudsl-python-extras/tests
179+
run: |
180+
181+
DESELECT=""
182+
if [[ $(python -c "print(__import__('sys').version_info < (3, 13))") == "True" ]]; then
183+
DESELECT="--deselect projects/eudsl-python-extras/tests/dialect/test_generics.py"
184+
fi
185+
186+
echo $DESELECT
187+
188+
python -m pytest projects/eudsl-python-extras/tests $DESELECT
179189
180190
release-eudsl-python-extras:
181191

projects/eudsl-python-extras/mlir/extras/ast/util.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from bytecode import ConcreteBytecode
1414
from cloudpickle import cloudpickle
15+
1516
from ...ir import Type
1617

1718

@@ -38,9 +39,9 @@ def ast_call(name, args=None, keywords=None):
3839
def get_module_cst(f):
3940
f_src = dedent(inspect.getsource(f))
4041
tree = ast.parse(f_src)
41-
assert isinstance(tree.body[0], ast.FunctionDef), (
42-
f"unexpected ast node {tree.body[0]}"
43-
)
42+
assert isinstance(
43+
tree.body[0], ast.FunctionDef
44+
), f"unexpected ast node {tree.body[0]}"
4445
return tree
4546

4647

@@ -167,9 +168,9 @@ def copy_func(f, new_closure: Dict = None):
167168

168169
def append_hidden_node(node_body, new_node):
169170
last_statement = node_body[-1]
170-
assert last_statement.end_lineno is not None, (
171-
f"last_statement {ast.unparse(last_statement)} must have end_lineno"
172-
)
171+
assert (
172+
last_statement.end_lineno is not None
173+
), f"last_statement {ast.unparse(last_statement)} must have end_lineno"
173174
new_node = ast.fix_missing_locations(
174175
set_lineno(new_node, last_statement.end_lineno)
175176
)

projects/eudsl-python-extras/mlir/extras/dialects/func.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from functools import update_wrapper
88
from typing import Optional, List, Union, TypeVar
99

10-
from ..ast.util import copy_func
10+
from .. import types as T
1111
from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject
12+
from ..ast.util import copy_func
1213
from ..meta import op_region_builder
13-
from .. import types as T
1414
from ..util import get_user_code_loc, make_maybe_no_args_decorator
1515
from ...dialects._ods_common import get_op_result_or_op_results
1616
from ...dialects.func import *
@@ -26,7 +26,6 @@
2626
Value,
2727
)
2828

29-
3029
_call = call
3130

3231

projects/eudsl-python-extras/mlir/extras/dialects/gpu.py

Lines changed: 81 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import partial
66
from typing import Any, List, Optional, Tuple, Union
77

8-
98
from .func import FuncBase
109
from .. import types as T
1110
from ..meta import (
@@ -25,6 +24,10 @@
2524
get_op_result_or_op_results,
2625
)
2726
from ...dialects.gpu import *
27+
28+
del constant
29+
# constant needs to be below gpu import because it needs to shadow upstream's arith.constant
30+
# noinspection PyUnusedImports
2831
from .arith import constant
2932
from ...ir import (
3033
ArrayAttr,
@@ -108,9 +111,9 @@ def z(cls):
108111

109112
def thread_id():
110113
return (
111-
block_dim.x * block_dim.y * thread_idx.z
112-
+ block_dim.x * thread_idx.y
113-
+ thread_idx.x
114+
block_dim.x * block_dim.y * thread_idx.z
115+
+ block_dim.x * thread_idx.y
116+
+ thread_idx.x
114117
)
115118

116119

@@ -126,7 +129,7 @@ def set_container_module(module):
126129

127130
@register_attribute_builder("DeviceMappingArrayAttr")
128131
def get_device_mapping_array_attr(
129-
mapping: List[Attribute], context: Optional[Context] = None
132+
mapping: List[Attribute], context: Optional[Context] = None
130133
) -> ArrayAttr:
131134
if context is None:
132135
context = Context.current
@@ -177,7 +180,7 @@ def smem_space(int=False):
177180
@_cext.register_operation(_Dialect, replace=True)
178181
class GPUModuleOp(GPUModuleOp):
179182
def __init__(
180-
self, sym_name, targets: Optional[List[Attribute]] = None, *, loc=None, ip=None
183+
self, sym_name, targets: Optional[List[Attribute]] = None, *, loc=None, ip=None
181184
):
182185
if targets is None:
183186
targets = []
@@ -188,8 +191,8 @@ def __init__(
188191
sym_name = (
189192
sym_name
190193
if (
191-
issubclass(type(sym_name), Attribute)
192-
or not AttrBuilder.contains("SymbolNameAttr")
194+
issubclass(type(sym_name), Attribute)
195+
or not AttrBuilder.contains("SymbolNameAttr")
193196
)
194197
else AttrBuilder.get("SymbolNameAttr")(sym_name, context=_ods_context)
195198
)
@@ -227,17 +230,17 @@ def __prepare__(cls, name, bases, **kwargs):
227230

228231
class GPUFuncOp(GPUFuncOp_):
229232
def __init__(
230-
self,
231-
sym_name,
232-
function_type,
233-
*,
234-
sym_visibility=None,
235-
arg_attrs=None,
236-
res_attrs=None,
237-
workgroup_attrib_attrs=None,
238-
private_attrib_attrs=None,
239-
loc=None,
240-
ip=None,
233+
self,
234+
sym_name,
235+
function_type,
236+
*,
237+
sym_visibility=None,
238+
arg_attrs=None,
239+
res_attrs=None,
240+
workgroup_attrib_attrs=None,
241+
private_attrib_attrs=None,
242+
loc=None,
243+
ip=None,
241244
):
242245
super().__init__(
243246
function_type=function_type,
@@ -255,17 +258,17 @@ def __init__(
255258
self.operation.attributes["arg_attrs"] = (
256259
arg_attrs
257260
if (
258-
isinstance(arg_attrs, Attribute)
259-
or not AttrBuilder.contains("DictArrayAttr")
261+
isinstance(arg_attrs, Attribute)
262+
or not AttrBuilder.contains("DictArrayAttr")
260263
)
261264
else AttrBuilder.get("DictArrayAttr")(arg_attrs, context=_ods_context)
262265
)
263266
if res_attrs is not None:
264267
self.operation.attributes["res_attrs"] = (
265268
res_attrs
266269
if (
267-
isinstance(res_attrs, Attribute)
268-
or not AttrBuilder.contains("DictArrayAttr")
270+
isinstance(res_attrs, Attribute)
271+
or not AttrBuilder.contains("DictArrayAttr")
269272
)
270273
else AttrBuilder.get("DictArrayAttr")(res_attrs, context=_ods_context)
271274
)
@@ -274,23 +277,23 @@ def __init__(
274277
self.operation.attributes["sym_visibility"] = (
275278
sym_visibility
276279
if (
277-
issubclass(type(sym_visibility), Attribute)
278-
or not AttrBuilder.contains("StrAttr")
280+
issubclass(type(sym_visibility), Attribute)
281+
or not AttrBuilder.contains("StrAttr")
279282
)
280283
else AttrBuilder.get("StrAttr")(sym_visibility, context=_ods_context)
281284
)
282285

283286

284287
class LaunchOp(LaunchOp):
285288
def __init__(
286-
self,
287-
grid_size: Tuple[Any, Any, Any],
288-
block_size: Tuple[Any, Any, Any],
289-
async_dependencies=None,
290-
dynamic_shared_memory_size: Optional[Value] = None,
291-
*,
292-
loc=None,
293-
ip=None,
289+
self,
290+
grid_size: Tuple[Any, Any, Any],
291+
block_size: Tuple[Any, Any, Any],
292+
async_dependencies=None,
293+
dynamic_shared_memory_size: Optional[Value] = None,
294+
*,
295+
loc=None,
296+
ip=None,
294297
):
295298
_ods_context = get_default_loc_context(loc)
296299
if async_dependencies is None:
@@ -309,13 +312,13 @@ def __init__(
309312

310313

311314
def launch_(
312-
grid_size: Tuple[Any, Any, Any],
313-
block_size: Tuple[Any, Any, Any],
314-
async_dependencies=None,
315-
dynamic_shared_memory_size: Optional[Value] = None,
316-
*,
317-
loc=None,
318-
ip=None,
315+
grid_size: Tuple[Any, Any, Any],
316+
block_size: Tuple[Any, Any, Any],
317+
async_dependencies=None,
318+
dynamic_shared_memory_size: Optional[Value] = None,
319+
*,
320+
loc=None,
321+
ip=None,
319322
):
320323
for size in [grid_size, block_size]:
321324
for i, s in enumerate(size):
@@ -337,17 +340,17 @@ def launch_(
337340

338341
class LaunchFuncOp(LaunchFuncOp):
339342
def __init__(
340-
self,
341-
kernel: List[str],
342-
grid_size: Tuple[Any, Any, Any],
343-
block_size: Tuple[Any, Any, Any],
344-
kernel_operands: List[Value] = None,
345-
async_dependencies=None,
346-
dynamic_shared_memory_size: Optional[Value] = None,
347-
async_object=None,
348-
*,
349-
loc=None,
350-
ip=None,
343+
self,
344+
kernel: List[str],
345+
grid_size: Tuple[Any, Any, Any],
346+
block_size: Tuple[Any, Any, Any],
347+
kernel_operands: List[Value] = None,
348+
async_dependencies=None,
349+
dynamic_shared_memory_size: Optional[Value] = None,
350+
async_object=None,
351+
*,
352+
loc=None,
353+
ip=None,
351354
):
352355
_ods_context = get_default_loc_context(loc)
353356
if async_dependencies is None:
@@ -370,15 +373,15 @@ def __init__(
370373

371374
class GPUFunc(FuncBase):
372375
def __call__(
373-
self,
374-
*kernel_operands: List[Value],
375-
grid_size: Tuple[Any, Any, Any],
376-
block_size: Tuple[Any, Any, Any],
377-
async_dependencies=None,
378-
dynamic_shared_memory_size: Optional[Value] = None,
379-
stream=None,
380-
loc=None,
381-
ip=None,
376+
self,
377+
*kernel_operands: List[Value],
378+
grid_size: Tuple[Any, Any, Any],
379+
block_size: Tuple[Any, Any, Any],
380+
async_dependencies=None,
381+
dynamic_shared_memory_size: Optional[Value] = None,
382+
stream=None,
383+
loc=None,
384+
ip=None,
382385
):
383386
for size in [grid_size, block_size]:
384387
for i, s in enumerate(size):
@@ -432,16 +435,16 @@ def __call__(self, *args, **kwargs):
432435

433436
@make_maybe_no_args_decorator
434437
def func(
435-
f,
436-
*,
437-
sym_visibility=None,
438-
arg_attrs=None,
439-
res_attrs=None,
440-
func_attrs=None,
441-
emit=False,
442-
loc=None,
443-
ip=None,
444-
emit_grid=False,
438+
f,
439+
*,
440+
sym_visibility=None,
441+
arg_attrs=None,
442+
res_attrs=None,
443+
func_attrs=None,
444+
emit=False,
445+
loc=None,
446+
ip=None,
447+
emit_grid=False,
445448
) -> Grid:
446449
func_ = GPUFunc(
447450
body_builder=f,
@@ -490,14 +493,14 @@ def wait(async_dependencies: Optional[List[Value]] = None, *, loc=None, ip=None)
490493

491494

492495
def alloc(
493-
sizes: Union[int, Value],
494-
element_type: Type = None,
495-
async_dependencies=None,
496-
dynamic_sizes=None,
497-
symbol_operands=None,
498-
host_shared=None,
499-
loc=None,
500-
ip=None,
496+
sizes: Union[int, Value],
497+
element_type: Type = None,
498+
async_dependencies=None,
499+
dynamic_sizes=None,
500+
symbol_operands=None,
501+
host_shared=None,
502+
loc=None,
503+
ip=None,
501504
):
502505
if symbol_operands is None:
503506
symbol_operands = []

0 commit comments

Comments
 (0)