Skip to content

Commit 224bc7d

Browse files
authored
[HWLegalizeModules] transform mux of array to muxes of elements (#9230)
* When disallowPackedArrays is set, transform muxes of arrays to muxes of elements. See the added tests for examples. * Some minor refactoring.
1 parent f63ea4d commit 224bc7d

File tree

2 files changed

+144
-53
lines changed

2 files changed

+144
-53
lines changed

lib/Dialect/SV/Transforms/HWLegalizeModules.cpp

Lines changed: 72 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616

17+
#include "circt/Dialect/Comb/CombOps.h"
1718
#include "circt/Dialect/HW/HWOps.h"
1819
#include "circt/Dialect/HW/HWTypes.h"
1920
#include "circt/Dialect/SV/SVOps.h"
@@ -42,6 +43,10 @@ struct HWLegalizeModulesPass
4243
private:
4344
void processPostOrder(Block &block);
4445
bool tryLoweringPackedArrayOp(Operation &op);
46+
template <typename ElementType>
47+
SmallVector<std::pair<Value, Value>>
48+
createIndexValuePairs(OpBuilder &builder, LocationAttr loc, hw::ArrayType ty,
49+
Value array);
4550
Value lowerLookupToCasez(Operation &op, Value input, Value index,
4651
mlir::Type elementType,
4752
SmallVector<Value> caseValues);
@@ -89,43 +94,20 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
8994
})
9095
.Case<hw::ArrayConcatOp>([&](hw::ArrayConcatOp concatOp) {
9196
// Redirect individual element uses (if any) to the input arguments.
92-
SmallVector<std::pair<Value, uint64_t>> arrays;
97+
SmallVector<Value> values;
98+
OpBuilder builder(concatOp);
9399
for (auto array : llvm::reverse(concatOp.getInputs())) {
94100
auto ty = hw::type_cast<hw::ArrayType>(array.getType());
95-
arrays.emplace_back(array, ty.getNumElements());
96-
}
97-
for (auto *user :
98-
llvm::make_early_inc_range(concatOp.getResult().getUsers())) {
99-
if (TypeSwitch<Operation *, bool>(user)
100-
.Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
101-
if (auto indexAndBitWidth =
102-
tryExtractIndexAndBitWidth(getOp.getIndex())) {
103-
auto [indexValue, bitWidth] = *indexAndBitWidth;
104-
// FIXME: More efficient search
105-
for (const auto &[array, size] : arrays) {
106-
if (indexValue >= size) {
107-
indexValue -= size;
108-
continue;
109-
}
110-
OpBuilder builder(getOp);
111-
getOp.getInputMutable().set(array);
112-
getOp.getIndexMutable().set(
113-
builder.createOrFold<hw::ConstantOp>(
114-
getOp.getLoc(), APInt(bitWidth, indexValue)));
115-
return true;
116-
}
117-
}
118-
119-
return false;
120-
})
121-
.Default([](auto op) { return false; }))
122-
continue;
123-
124-
op.emitError("unsupported packed array expression");
125-
signalPassFailure();
101+
const auto indexValues = createIndexValuePairs<hw::ArrayGetOp>(
102+
builder, concatOp.getLoc(), ty, array);
103+
for (const auto &[_, value] : indexValues) {
104+
values.push_back(value);
105+
}
126106
}
107+
if (!processUsers(op, concatOp.getResult(), values))
108+
return false;
127109

128-
// Remove the original op.
110+
// Remove original op.
129111
return true;
130112
})
131113
.Case<hw::ArrayCreateOp>([&](hw::ArrayCreateOp createOp) {
@@ -147,15 +129,12 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
147129
// Generate case value element lookups.
148130
auto ty = hw::type_cast<hw::ArrayType>(getOp.getInput().getType());
149131
OpBuilder builder(getOp);
132+
auto loc = op.getLoc();
133+
const auto indexValues = createIndexValuePairs<hw::ArrayGetOp>(
134+
builder, loc, ty, getOp.getInput());
150135
SmallVector<Value> caseValues;
151-
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
152-
auto loc = op.getLoc();
153-
auto index = builder.createOrFold<hw::ConstantOp>(
154-
loc, APInt(llvm::Log2_64_Ceil(e), i));
155-
auto element =
156-
hw::ArrayGetOp::create(builder, loc, getOp.getInput(), index);
157-
caseValues.push_back(element);
158-
}
136+
for (const auto &[_, value] : indexValues)
137+
caseValues.push_back(value);
159138

160139
// Transform array index op into casez statement.
161140
auto theWire = lowerLookupToCasez(op, getOp.getInput(), index,
@@ -183,13 +162,11 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
183162
// Generate case value element lookups.
184163
auto ty = hw::type_cast<hw::ArrayType>(inout.getElementType());
185164
OpBuilder builder(&op);
165+
auto loc = op.getLoc();
166+
const auto indexValues = createIndexValuePairs<sv::ArrayIndexInOutOp>(
167+
builder, loc, ty, indexOp.getInput());
186168
SmallVector<Value> caseValues;
187-
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
188-
auto loc = op.getLoc();
189-
auto index = builder.createOrFold<hw::ConstantOp>(
190-
loc, APInt(llvm::Log2_64_Ceil(e), i));
191-
auto element = sv::ArrayIndexInOutOp::create(
192-
builder, loc, indexOp.getInput(), index);
169+
for (const auto &[_, element] : indexValues) {
193170
auto readElement = sv::ReadInOutOp::create(builder, loc, element);
194171
caseValues.push_back(readElement);
195172
}
@@ -211,14 +188,12 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
211188
return false;
212189

213190
OpBuilder builder(assignOp);
214-
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
215-
auto loc = op.getLoc();
216-
auto index = builder.createOrFold<hw::ConstantOp>(
217-
loc, APInt(llvm::Log2_64_Ceil(e), i));
191+
auto loc = op.getLoc();
192+
const auto indexValues = createIndexValuePairs<hw::ArrayGetOp>(
193+
builder, loc, ty, assignOp.getSrc());
194+
for (const auto &[index, srcElement] : indexValues) {
218195
auto dstElement = sv::ArrayIndexInOutOp::create(
219196
builder, loc, assignOp.getDest(), index);
220-
auto srcElement =
221-
hw::ArrayGetOp::create(builder, loc, assignOp.getSrc(), index);
222197
sv::PAssignOp::create(builder, loc, dstElement, srcElement);
223198
}
224199

@@ -251,9 +226,53 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
251226
// Remove original reg.
252227
return true;
253228
})
229+
.Case<comb::MuxOp>([&](comb::MuxOp muxOp) {
230+
// Transform array mux into individual element muxes.
231+
auto ty = hw::type_dyn_cast<hw::ArrayType>(muxOp.getType());
232+
if (!ty)
233+
return false;
234+
235+
OpBuilder builder(muxOp);
236+
237+
auto trueValues = createIndexValuePairs<hw::ArrayGetOp>(
238+
builder, muxOp.getLoc(), ty, muxOp.getTrueValue());
239+
auto falseValues = createIndexValuePairs<hw::ArrayGetOp>(
240+
builder, muxOp.getLoc(), ty, muxOp.getFalseValue());
241+
242+
SmallVector<Value> muxedValues;
243+
244+
for (size_t i = 0, e = trueValues.size(); i < e; i++) {
245+
const auto &[trueIndex, trueValue] = trueValues[i];
246+
const auto &[falseIndex, falseValue] = falseValues[i];
247+
muxedValues.push_back(
248+
comb::MuxOp::create(builder, muxOp.getLoc(), muxOp.getCond(),
249+
trueValue, falseValue, muxOp.getTwoState()));
250+
}
251+
252+
if (!processUsers(op, muxOp.getResult(), muxedValues))
253+
return false;
254+
255+
// Remove original mux.
256+
return true;
257+
})
254258
.Default([&](auto op) { return false; });
255259
}
256260

261+
template <typename ElementType>
262+
SmallVector<std::pair<Value, Value>>
263+
HWLegalizeModulesPass::createIndexValuePairs(OpBuilder &builder,
264+
LocationAttr loc, hw::ArrayType ty,
265+
Value array) {
266+
SmallVector<std::pair<Value, Value>> result;
267+
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
268+
auto index = builder.createOrFold<hw::ConstantOp>(
269+
loc, APInt(llvm::Log2_64_Ceil(e), i));
270+
auto element = ElementType::create(builder, loc, array, index);
271+
result.emplace_back(index, element);
272+
}
273+
return result;
274+
}
275+
257276
Value HWLegalizeModulesPass::lowerLookupToCasez(Operation &op, Value input,
258277
Value index,
259278
mlir::Type elementType,

test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,46 @@ hw.module @array_create_get_default(in %arg0: i8, in %arg1: i8, in %arg2: i8, in
8181
}
8282
}
8383

84+
// CHECK-LABEL: hw.module @array_muxed_create_get_default
85+
hw.module @array_muxed_create_get_default(in %arg0: i8, in %arg1: i8, in %arg2: i8, in %arg3: i8, in %arg4: i8, in %arg5: i8,
86+
in %array_sel: i1, in %index_sel: i2) {
87+
// CHECK: sv.initial {
88+
sv.initial {
89+
%three_array1 = hw.array_create %arg2, %arg1, %arg0 : i8
90+
%three_array2 = hw.array_create %arg5, %arg4, %arg3 : i8
91+
92+
// CHECK: %0 = comb.mux %array_sel, %arg0, %arg3 : i8
93+
// CHECK: %1 = comb.mux %array_sel, %arg1, %arg4 : i8
94+
// CHECK: %2 = comb.mux %array_sel, %arg2, %arg5 : i8
95+
%muxed = comb.mux %array_sel, %three_array1, %three_array2 : !hw.array<3xi8>
96+
97+
// CHECK: %x_i8 = sv.constantX : i8
98+
// CHECK: sv.case casez %index_sel : i2
99+
// CHECK: case b00: {
100+
// CHECK: sv.bpassign %casez_tmp, %0 : i8
101+
// CHECK: }
102+
// CHECK: case b01: {
103+
// CHECK: sv.bpassign %casez_tmp, %1 : i8
104+
// CHECK: }
105+
// CHECK: case b10: {
106+
// CHECK: sv.bpassign %casez_tmp, %2 : i8
107+
// CHECK: }
108+
// CHECK: default: {
109+
// CHECK: sv.bpassign %casez_tmp, %x_i8 : i8
110+
// CHECK: }
111+
112+
// CHECK: %3 = sv.read_inout %casez_tmp : !hw.inout<i8>
113+
%2 = hw.array_get %muxed[%index_sel] : !hw.array<3xi8>, i2
114+
115+
// CHECK: %4 = comb.icmp eq %3, %arg2 : i8
116+
// CHECK: sv.if %4 {
117+
%cond = comb.icmp eq %2, %arg2 : i8
118+
sv.if %cond {
119+
sv.fatal 1
120+
}
121+
}
122+
}
123+
84124
// CHECK-LABEL: hw.module @array_create_concat_get_default
85125
hw.module @array_create_concat_get_default(in %arg0: i8, in %arg1: i8, in %arg2: i8, in %arg3: i8,
86126
in %sel: i2) {
@@ -143,6 +183,38 @@ hw.module @array_constant_get_comb(in %sel: i2, out a: i8) {
143183
hw.output %1 : i8
144184
}
145185

186+
// CHECK-LABEL: hw.module @array_muxed_constant_get_comb
187+
hw.module @array_muxed_constant_get_comb(in %array_sel: i1, in %index_sel: i2, out a: i8) {
188+
// CHECK: %0 = comb.mux %array_sel, %c3_i8, %c7_i8 : i8
189+
// CHECK: %1 = comb.mux %array_sel, %c2_i8, %c6_i8 : i8
190+
// CHECK: %2 = comb.mux %array_sel, %c1_i8, %c5_i8 : i8
191+
// CHECK: %3 = comb.mux %array_sel, %c0_i8, %c4_i8 : i8
192+
// CHECK: %casez_tmp = sv.reg : !hw.inout<i8>
193+
// CHECK: sv.alwayscomb {
194+
// CHECK: sv.case casez %index_sel : i2
195+
// CHECK: case b00: {
196+
// CHECK: sv.bpassign %casez_tmp, %0 : i8
197+
// CHECK: }
198+
// CHECK: case b01: {
199+
// CHECK: sv.bpassign %casez_tmp, %1 : i8
200+
// CHECK: }
201+
// CHECK: case b10: {
202+
// CHECK: sv.bpassign %casez_tmp, %2 : i8
203+
// CHECK: }
204+
// CHECK: default: {
205+
// CHECK: sv.bpassign %casez_tmp, %3 : i8
206+
// CHECK: }
207+
// CHECK: }
208+
%0 = hw.aggregate_constant [0 : i8, 1 : i8, 2 : i8, 3 : i8] : !hw.array<4xi8>
209+
%1 = hw.aggregate_constant [4 : i8, 5 : i8, 6 : i8, 7 : i8] : !hw.array<4xi8>
210+
%muxed = comb.mux %array_sel, %0, %1 : !hw.array<4xi8>
211+
// CHECK: %4 = sv.read_inout %casez_tmp : !hw.inout<i8>
212+
%3 = hw.array_get %muxed[%index_sel] : !hw.array<4xi8>, i2
213+
214+
// CHECK: hw.output %4 : i8
215+
hw.output %3 : i8
216+
}
217+
146218
// CHECK-LABEL: hw.module @array_reg_mux_2
147219
hw.module @array_reg_mux_2(in %clock: i1, in %arg0: i8, in %arg1: i8, in %sel: i1, out a: i8) {
148220
// CHECK: %reg = sv.reg : !hw.inout<i8>

0 commit comments

Comments
 (0)