Skip to content

Commit a32cf72

Browse files
committed
[eudsl-python-extras] better handle callable generics
1 parent 34fbab8 commit a32cf72

File tree

3 files changed

+166
-6
lines changed

3 files changed

+166
-6
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/func.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def __init__(
179179
self.call_op_ctor = call_op_ctor
180180
self.arg_attrs = arg_attrs
181181
self.res_attrs = res_attrs
182-
self.generics = generics
182+
self.generics = copy.deepcopy(generics)
183183
self.loc = loc
184184
self.ip = ip
185185
self._func_op = None
@@ -379,18 +379,21 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
379379
)
380380
body_builder.__closure__[free_i].cell_contents = r.val
381381

382+
name_mangled_generics = []
383+
for r in reified_type_params:
384+
t, v = r.type, r.val
385+
if callable(v):
386+
v = v.__name__
387+
name_mangled_generics.append(f"{t}_{v}")
388+
382389
return FuncBase(
383390
body_builder,
384391
self.func_op_ctor,
385392
self.return_op_ctor,
386393
self.call_op_ctor,
387394
return_types=self.return_types,
388395
sym_visibility=self.sym_visibility,
389-
sym_name=(
390-
self.func_name
391-
+ "_"
392-
+ "_".join([f"{r.type}_{r.val}" for r in reified_type_params])
393-
),
396+
sym_name=(self.func_name + "_" + "_".join(name_mangled_generics)),
394397
arg_attrs=self.arg_attrs,
395398
res_attrs=self.res_attrs,
396399
func_attrs=self.func_attrs,

projects/eudsl-python-extras/tests/dialect/test_func.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,52 @@ def mat_product_kernel(
205205
one = arith.constant(1, dtype)
206206

207207
mat_product_kernel[32, 32, 32, T.i32()].emit()
208+
mat_product_kernel[32, 32, 32, T.f32()].emit()
208209

209210
# CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_i32(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) {
210211
# CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
211212
# CHECK: return
212213
# CHECK: }
214+
# CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
215+
# CHECK: %cst = arith.constant 1.000000e+00 : f32
216+
# CHECK: return
217+
# CHECK: }
218+
219+
filecheck_with_comments(ctx.module)
220+
221+
222+
def test_generics_callable(ctx: MLIRContext):
    """Emit kernels parameterized over a callable generic.

    The callable's ``__name__`` is expected to be used for symbol
    name-mangling (e.g. ``..._function_maximumf``) so that distinct
    specializations of the same kernel get distinct symbols.
    """
    _op = TypeVar("_op")

    @func(generics=[_op])
    def mat_product_kernel1():
        lhs = arith.constant(1, T.f32())
        _ = _op(lhs, lhs)

    @func(generics=[_op])
    def mat_product_kernel2():
        lhs = arith.constant(1, T.f32())
        _ = _op(lhs, lhs)

    # Same emission order as before: kernel1/max, kernel2/min, kernel2/max.
    specializations = (
        (mat_product_kernel1, arith.maximumf),
        (mat_product_kernel2, arith.minimumf),
        (mat_product_kernel2, arith.maximumf),
    )
    for kernel, op in specializations:
        kernel[op,].emit()

    # CHECK: func.func @mat_product_kernel1_function_maximumf() {
    # CHECK: %cst = arith.constant 1.000000e+00 : f32
    # CHECK: %0 = arith.maximumf %cst, %cst : f32
    # CHECK: return
    # CHECK: }
    # CHECK: func.func @mat_product_kernel2_function_minimumf() {
    # CHECK: %cst = arith.constant 1.000000e+00 : f32
    # CHECK: %0 = arith.minimumf %cst, %cst : f32
    # CHECK: return
    # CHECK: }
    # CHECK: func.func @mat_product_kernel2_function_maximumf() {
    # CHECK: %cst = arith.constant 1.000000e+00 : f32
    # CHECK: %0 = arith.maximumf %cst, %cst : f32
    # CHECK: return
    # CHECK: }

    filecheck_with_comments(ctx.module)
215256

projects/eudsl-python-extras/tests/dialect/test_linalg.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
filecheck_with_comments,
1919
mlir_ctx as ctx,
2020
)
21+
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2122

2223
# needed since the fix isn't defined here nor conftest.py
2324
pytest.mark.usefixtures("ctx")
@@ -134,3 +135,118 @@ def maxpool3d(
134135
# CHECK: return
135136
# CHECK: }
136137
filecheck_with_comments(maxpool3d_k)
138+
139+
140+
def test_pooling_ncdhw_max_parallel(ctx: MLIRContext):
    """Lower a generic NCDHW max-pooling kernel to scf.parallel loops.

    Kernel sizes, strides, and dilations are all supplied as generic
    (TypeVar) parameters and reified at emission time; the emitted func
    is then bufferized and converted to parallel loops.
    """
    S = ShapedType.get_dynamic_size()

    generic_names = [
        "kernel_size_0",
        "kernel_size_1",
        "kernel_size_2",
        "stride_0",
        "stride_1",
        "stride_2",
        "dilation_0",
        "dilation_1",
        "dilation_2",
    ]
    generics = (
        kernel_size_0,
        kernel_size_1,
        kernel_size_2,
        stride_0,
        stride_1,
        stride_2,
        dilation_0,
        dilation_1,
        dilation_2,
    ) = [TypeVar(name) for name in generic_names]

    @func(generics=generics)
    def maxpool3d(
        input: T.memref(S, S, S, S, S, T.f32()),
        output: T.memref(S, S, S, S, S, T.f32()),
    ):
        # The alloca only exists to carry the kernel-window shape into the
        # linalg op; its contents are never read.
        kernel_shape_surrogate = memref.alloca(
            (kernel_size_0, kernel_size_1, kernel_size_2),
            T.f32(),
        )

        linalg.pooling_ncdhw_max(
            input,
            kernel_shape_surrogate,
            output,
            strides=[stride_0, stride_1, stride_2],
            dilations=[dilation_0, dilation_1, dilation_2],
        )

    kernel_sizes = [1, 2, 3]
    strides = [4, 5, 6]
    dilations = [7, 8, 9]
    # Subscripting with one 9-tuple is equivalent to the 9-argument form.
    maxpool3d_k = maxpool3d[tuple(kernel_sizes + strides + dilations)].emit()
    module = run_pipeline(
        ctx.module,
        Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()),
    )
    # CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1 * 7)>
    # CHECK: #map1 = affine_map<(d0, d1) -> (d0 * 5 + d1 * 8)>
    # CHECK: #map2 = affine_map<(d0, d1) -> (d0 * 6 + d1 * 9)>
    # CHECK: module {
    # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_4_int_5_int_6_int_7_int_8_int_9(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>) {
    # CHECK: %c4 = arith.constant 4 : index
    # CHECK: %c3 = arith.constant 3 : index
    # CHECK: %c2 = arith.constant 2 : index
    # CHECK: %c1 = arith.constant 1 : index
    # CHECK: %c0 = arith.constant 0 : index
    # CHECK: %dim = memref.dim %arg0, %c0 : memref<?x?x?x?x?xf32>
    # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref<?x?x?x?x?xf32>
    # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref<?x?x?x?x?xf32>
    # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref<?x?x?x?x?xf32>
    # CHECK: %dim_3 = memref.dim %arg1, %c4 : memref<?x?x?x?x?xf32>
    # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2, %dim_3) step (%c1, %c1, %c1, %c1, %c1) {
    # CHECK: scf.for %arg7 = %c0 to %c1 step %c1 {
    # CHECK: scf.for %arg8 = %c0 to %c2 step %c1 {
    # CHECK: scf.for %arg9 = %c0 to %c3 step %c1 {
    # CHECK: %0 = affine.apply #map(%arg4, %arg7)
    # CHECK: %1 = affine.apply #map1(%arg5, %arg8)
    # CHECK: %2 = affine.apply #map2(%arg6, %arg9)
    # CHECK: %3 = memref.load %arg0[%arg2, %arg3, %0, %1, %2] : memref<?x?x?x?x?xf32>
    # CHECK: %4 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref<?x?x?x?x?xf32>
    # CHECK: %5 = arith.maximumf %3, %4 : f32
    # CHECK: memref.store %5, %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref<?x?x?x?x?xf32>
    # CHECK: }
    # CHECK: }
    # CHECK: }
    # CHECK: scf.reduce
    # CHECK: }
    # CHECK: return
    # CHECK: }
    # CHECK: }
    filecheck_with_comments(module)

0 commit comments

Comments
 (0)