Skip to content

Commit df43bf1

Browse files
Correctly handle comparisons with index types (#234)
Comparisons with index types was failing because of an unguarded check of `dtype.is_signed`, which is not present on ir.IndexType.
1 parent 44daf43 commit df43bf1

File tree

2 files changed

+118
-49
lines changed

2 files changed

+118
-49
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/arith.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,11 @@ def _binary_op(
371371
if signedness is not None:
372372
predicate = signedness + predicate
373373
else:
374-
if lhs.dtype.is_signed or lhs.dtype.is_signless:
375-
predicate = "s" + predicate
376-
else:
377-
assert lhs.dtype.is_unsigned
374+
if _is_index_type(lhs.dtype) or lhs.dtype.is_unsigned:
378375
predicate = "u" + predicate
376+
else:
377+
assert lhs.dtype.is_signed or lhs.dtype.is_signless
378+
predicate = "s" + predicate
379379
return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype)
380380
else:
381381
return lhs.__class__(op(lhs, rhs, loc=loc), dtype=lhs.dtype)

projects/eudsl-python-extras/tests/dialect/test_arith.py

Lines changed: 114 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def test_arithmetic(ctx: MLIRContext):
6363
one // two
6464
one % two
6565

66+
two = arith.constant(2, index=True)
67+
one + two
68+
one - two
69+
one / two
70+
one // two
71+
one % two
72+
6673
one = arith.constant(1.0)
6774
two = arith.constant(2.0)
6875
one + two
@@ -79,19 +86,31 @@ def test_arithmetic(ctx: MLIRContext):
7986

8087
ctx.module.operation.verify()
8188

82-
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
83-
# CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
84-
# CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : i32
85-
# CHECK: %[[VAL_3:.*]] = arith.subi %[[VAL_0]], %[[VAL_1]] : i32
86-
# CHECK: %[[VAL_4:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : i32
87-
# CHECK: %[[VAL_5:.*]] = arith.floordivsi %[[VAL_0]], %[[VAL_1]] : i32
88-
# CHECK: %[[VAL_6:.*]] = arith.remsi %[[VAL_0]], %[[VAL_1]] : i32
89-
# CHECK: %[[VAL_7:.*]] = arith.constant 1.000000e+00 : f32
90-
# CHECK: %[[VAL_8:.*]] = arith.constant 2.000000e+00 : f32
91-
# CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
92-
# CHECK: %[[VAL_10:.*]] = arith.subf %[[VAL_7]], %[[VAL_8]] : f32
93-
# CHECK: %[[VAL_11:.*]] = arith.divf %[[VAL_7]], %[[VAL_8]] : f32
94-
# CHECK: %[[VAL_12:.*]] = arith.remf %[[VAL_7]], %[[VAL_8]] : f32
89+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
90+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 2 : i32
91+
# CHECK: %[[ADDI_0:.*]] = arith.addi %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
92+
# CHECK: %[[SUBI_0:.*]] = arith.subi %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
93+
# CHECK: %[[DIVSI_0:.*]] = arith.divsi %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
94+
# CHECK: %[[FLOORDIVSI_0:.*]] = arith.floordivsi %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
95+
# CHECK: %[[REMSI_0:.*]] = arith.remsi %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
96+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 2 : index
97+
# CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[CONSTANT_2]] : index to i32
98+
# CHECK: %[[ADDI_1:.*]] = arith.addi %[[CONSTANT_0]], %[[INDEX_CAST_0]] : i32
99+
# CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[CONSTANT_2]] : index to i32
100+
# CHECK: %[[SUBI_1:.*]] = arith.subi %[[CONSTANT_0]], %[[INDEX_CAST_1]] : i32
101+
# CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[CONSTANT_2]] : index to i32
102+
# CHECK: %[[DIVSI_1:.*]] = arith.divsi %[[CONSTANT_0]], %[[INDEX_CAST_2]] : i32
103+
# CHECK: %[[INDEX_CAST_3:.*]] = arith.index_cast %[[CONSTANT_2]] : index to i32
104+
# CHECK: %[[FLOORDIVSI_1:.*]] = arith.floordivsi %[[CONSTANT_0]], %[[INDEX_CAST_3]] : i32
105+
# CHECK: %[[INDEX_CAST_4:.*]] = arith.index_cast %[[CONSTANT_2]] : index to i32
106+
# CHECK: %[[REMSI_1:.*]] = arith.remsi %[[CONSTANT_0]], %[[INDEX_CAST_4]] : i32
107+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 1.000000e+00 : f32
108+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 2.000000e+00 : f32
109+
# CHECK: %[[ADDF_0:.*]] = arith.addf %[[CONSTANT_3]], %[[CONSTANT_4]] : f32
110+
# CHECK: %[[SUBF_0:.*]] = arith.subf %[[CONSTANT_3]], %[[CONSTANT_4]] : f32
111+
# CHECK: %[[DIVF_0:.*]] = arith.divf %[[CONSTANT_3]], %[[CONSTANT_4]] : f32
112+
# CHECK: %[[REMF_0:.*]] = arith.remf %[[CONSTANT_3]], %[[CONSTANT_4]] : f32
113+
95114
filecheck_with_comments(ctx.module)
96115

97116

@@ -111,19 +130,80 @@ def test_r_arithmetic(ctx: MLIRContext):
111130
filecheck_with_comments(ctx.module)
112131

113132

114-
def test_arith_cmp(ctx: MLIRContext):
115-
one = arith.constant(1)
116-
two = arith.constant(2)
117-
one < two
118-
one <= two
119-
one > two
120-
one >= two
121-
one == two
122-
one != two
123-
one & two
124-
one | two
125-
assert one._ne(two)
126-
assert not one._eq(two)
133+
def test_arith_cmpi(ctx: MLIRContext):
134+
for kind1, kind2 in [({}, {}), ({'index': True}, {'index': True}), ({'index': True}, {}), ({}, {'index': True})]:
135+
one = arith.constant(1, **kind1)
136+
two = arith.constant(2, **kind2)
137+
one < two
138+
one <= two
139+
one > two
140+
one >= two
141+
one == two
142+
one != two
143+
one & two
144+
one | two
145+
assert one._ne(two)
146+
assert not one._eq(two)
147+
148+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
149+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 2 : i32
150+
# CHECK: %[[CMPI_0:.*]] = arith.cmpi slt, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
151+
# CHECK: %[[CMPI_1:.*]] = arith.cmpi sle, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
152+
# CHECK: %[[CMPI_2:.*]] = arith.cmpi sgt, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
153+
# CHECK: %[[CMPI_3:.*]] = arith.cmpi sge, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
154+
# CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
155+
# CHECK: %[[CMPI_5:.*]] = arith.cmpi ne, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
156+
# CHECK: %[[ANDI_0:.*]] = arith.andi %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
157+
# CHECK: %[[ORI_0:.*]] = arith.ori %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
158+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
159+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 2 : index
160+
# CHECK: %[[CMPI_6:.*]] = arith.cmpi ult, %[[CONSTANT_2]], %[[CONSTANT_3]] : index
161+
# CHECK: %[[CMPI_7:.*]] = arith.cmpi ule, %[[CONSTANT_2]], %[[CONSTANT_3]] : index
162+
# CHECK: %[[CMPI_8:.*]] = arith.cmpi ugt, %[[CONSTANT_2]], %[[CONSTANT_3]] : index
163+
# CHECK: %[[CMPI_9:.*]] = arith.cmpi uge, %[[CONSTANT_2]], %[[CONSTANT_3]] : index
164+
# CHECK: %[[CMPI_10:.*]] = arith.cmpi eq, %[[CONSTANT_2]], %[[CONSTANT_3]] : index
165+
# CHECK: %[[CMPI_11:.*]] = arith.cmpi ne, %[[CONSTANT_2]], %[[CONSTANT_3]] : index
166+
# CHECK: %[[ANDI_1:.*]] = arith.andi %[[CONSTANT_2]], %[[CONSTANT_3]] : index
167+
# CHECK: %[[ORI_1:.*]] = arith.ori %[[CONSTANT_2]], %[[CONSTANT_3]] : index
168+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : index
169+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
170+
# CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
171+
# CHECK: %[[CMPI_12:.*]] = arith.cmpi ult, %[[CONSTANT_4]], %[[INDEX_CAST_0]] : index
172+
# CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
173+
# CHECK: %[[CMPI_13:.*]] = arith.cmpi ule, %[[CONSTANT_4]], %[[INDEX_CAST_1]] : index
174+
# CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
175+
# CHECK: %[[CMPI_14:.*]] = arith.cmpi ugt, %[[CONSTANT_4]], %[[INDEX_CAST_2]] : index
176+
# CHECK: %[[INDEX_CAST_3:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
177+
# CHECK: %[[CMPI_15:.*]] = arith.cmpi uge, %[[CONSTANT_4]], %[[INDEX_CAST_3]] : index
178+
# CHECK: %[[INDEX_CAST_4:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
179+
# CHECK: %[[CMPI_16:.*]] = arith.cmpi eq, %[[CONSTANT_4]], %[[INDEX_CAST_4]] : index
180+
# CHECK: %[[INDEX_CAST_5:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
181+
# CHECK: %[[CMPI_17:.*]] = arith.cmpi ne, %[[CONSTANT_4]], %[[INDEX_CAST_5]] : index
182+
# CHECK: %[[INDEX_CAST_6:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
183+
# CHECK: %[[ANDI_2:.*]] = arith.andi %[[CONSTANT_4]], %[[INDEX_CAST_6]] : index
184+
# CHECK: %[[INDEX_CAST_7:.*]] = arith.index_cast %[[CONSTANT_5]] : i32 to index
185+
# CHECK: %[[ORI_2:.*]] = arith.ori %[[CONSTANT_4]], %[[INDEX_CAST_7]] : index
186+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant 1 : i32
187+
# CHECK: %[[CONSTANT_7:.*]] = arith.constant 2 : index
188+
# CHECK: %[[INDEX_CAST_8:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
189+
# CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[CONSTANT_6]], %[[INDEX_CAST_8]] : i32
190+
# CHECK: %[[INDEX_CAST_9:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
191+
# CHECK: %[[CMPI_19:.*]] = arith.cmpi sle, %[[CONSTANT_6]], %[[INDEX_CAST_9]] : i32
192+
# CHECK: %[[INDEX_CAST_10:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
193+
# CHECK: %[[CMPI_20:.*]] = arith.cmpi sgt, %[[CONSTANT_6]], %[[INDEX_CAST_10]] : i32
194+
# CHECK: %[[INDEX_CAST_11:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
195+
# CHECK: %[[CMPI_21:.*]] = arith.cmpi sge, %[[CONSTANT_6]], %[[INDEX_CAST_11]] : i32
196+
# CHECK: %[[INDEX_CAST_12:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
197+
# CHECK: %[[CMPI_22:.*]] = arith.cmpi eq, %[[CONSTANT_6]], %[[INDEX_CAST_12]] : i32
198+
# CHECK: %[[INDEX_CAST_13:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
199+
# CHECK: %[[CMPI_23:.*]] = arith.cmpi ne, %[[CONSTANT_6]], %[[INDEX_CAST_13]] : i32
200+
# CHECK: %[[INDEX_CAST_14:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
201+
# CHECK: %[[ANDI_3:.*]] = arith.andi %[[CONSTANT_6]], %[[INDEX_CAST_14]] : i32
202+
# CHECK: %[[INDEX_CAST_15:.*]] = arith.index_cast %[[CONSTANT_7]] : index to i32
203+
# CHECK: %[[ORI_3:.*]] = arith.ori %[[CONSTANT_6]], %[[INDEX_CAST_15]] : i32
204+
filecheck_with_comments(ctx.module)
205+
206+
def test_arith_cmpf(ctx: MLIRContext):
127207

128208
one = arith.constant(1.0)
129209
two = arith.constant(2.0)
@@ -138,25 +218,14 @@ def test_arith_cmp(ctx: MLIRContext):
138218

139219
ctx.module.operation.verify()
140220

141-
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
142-
# CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
143-
# CHECK: %[[VAL_2:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_1]] : i32
144-
# CHECK: %[[VAL_3:.*]] = arith.cmpi sle, %[[VAL_0]], %[[VAL_1]] : i32
145-
# CHECK: %[[VAL_4:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_1]] : i32
146-
# CHECK: %[[VAL_5:.*]] = arith.cmpi sge, %[[VAL_0]], %[[VAL_1]] : i32
147-
# CHECK: %[[VAL_6:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_1]] : i32
148-
# CHECK: %[[VAL_7:.*]] = arith.cmpi ne, %[[VAL_0]], %[[VAL_1]] : i32
149-
# CHECK: %[[VAL_8:.*]] = arith.andi %[[VAL_0]], %[[VAL_1]] : i32
150-
# CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_0]], %[[VAL_1]] : i32
151-
# CHECK: %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32
152-
# CHECK: %[[VAL_11:.*]] = arith.constant 2.000000e+00 : f32
153-
# CHECK: %[[VAL_12:.*]] = arith.cmpf olt, %[[VAL_10]], %[[VAL_11]] : f32
154-
# CHECK: %[[VAL_13:.*]] = arith.cmpf ole, %[[VAL_10]], %[[VAL_11]] : f32
155-
# CHECK: %[[VAL_14:.*]] = arith.cmpf ogt, %[[VAL_10]], %[[VAL_11]] : f32
156-
# CHECK: %[[VAL_15:.*]] = arith.cmpf oge, %[[VAL_10]], %[[VAL_11]] : f32
157-
# CHECK: %[[VAL_16:.*]] = arith.cmpf oeq, %[[VAL_10]], %[[VAL_11]] : f32
158-
# CHECK: %[[VAL_17:.*]] = arith.cmpf one, %[[VAL_10]], %[[VAL_11]] : f32
159-
221+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32
222+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 2.000000e+00 : f32
223+
# CHECK: %[[CMPF_0:.*]] = arith.cmpf olt, %[[CONSTANT_0]], %[[CONSTANT_1]] : f32
224+
# CHECK: %[[CMPF_1:.*]] = arith.cmpf ole, %[[CONSTANT_0]], %[[CONSTANT_1]] : f32
225+
# CHECK: %[[CMPF_2:.*]] = arith.cmpf ogt, %[[CONSTANT_0]], %[[CONSTANT_1]] : f32
226+
# CHECK: %[[CMPF_3:.*]] = arith.cmpf oge, %[[CONSTANT_0]], %[[CONSTANT_1]] : f32
227+
# CHECK: %[[CMPF_4:.*]] = arith.cmpf oeq, %[[CONSTANT_0]], %[[CONSTANT_1]] : f32
228+
# CHECK: %[[CMPF_5:.*]] = arith.cmpf one, %[[CONSTANT_0]], %[[CONSTANT_1]] : f32
160229
filecheck_with_comments(ctx.module)
161230

162231

0 commit comments

Comments
 (0)