1414//
1515// ===----------------------------------------------------------------------===//
1616
17+ #include " circt/Dialect/Comb/CombOps.h"
1718#include " circt/Dialect/HW/HWOps.h"
1819#include " circt/Dialect/HW/HWTypes.h"
1920#include " circt/Dialect/SV/SVOps.h"
@@ -42,6 +43,10 @@ struct HWLegalizeModulesPass
4243private:
4344 void processPostOrder (Block &block);
4445 bool tryLoweringPackedArrayOp (Operation &op);
46+ template <typename ElementType>
47+ SmallVector<std::pair<Value, Value>>
48+ createIndexValuePairs (OpBuilder &builder, LocationAttr loc, hw::ArrayType ty,
49+ Value array);
4550 Value lowerLookupToCasez (Operation &op, Value input, Value index,
4651 mlir::Type elementType,
4752 SmallVector<Value> caseValues);
@@ -89,43 +94,20 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
8994 })
9095 .Case <hw::ArrayConcatOp>([&](hw::ArrayConcatOp concatOp) {
9196 // Redirect individual element uses (if any) to the input arguments.
92- SmallVector<std::pair<Value, uint64_t >> arrays;
97+ SmallVector<Value> values;
98+ OpBuilder builder (concatOp);
9399 for (auto array : llvm::reverse (concatOp.getInputs ())) {
94100 auto ty = hw::type_cast<hw::ArrayType>(array.getType ());
95- arrays.emplace_back (array, ty.getNumElements ());
96- }
97- for (auto *user :
98- llvm::make_early_inc_range (concatOp.getResult ().getUsers ())) {
99- if (TypeSwitch<Operation *, bool >(user)
100- .Case <hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
101- if (auto indexAndBitWidth =
102- tryExtractIndexAndBitWidth (getOp.getIndex ())) {
103- auto [indexValue, bitWidth] = *indexAndBitWidth;
104- // FIXME: More efficient search
105- for (const auto &[array, size] : arrays) {
106- if (indexValue >= size) {
107- indexValue -= size;
108- continue ;
109- }
110- OpBuilder builder (getOp);
111- getOp.getInputMutable ().set (array);
112- getOp.getIndexMutable ().set (
113- builder.createOrFold <hw::ConstantOp>(
114- getOp.getLoc (), APInt (bitWidth, indexValue)));
115- return true ;
116- }
117- }
118-
119- return false ;
120- })
121- .Default ([](auto op) { return false ; }))
122- continue ;
123-
124- op.emitError (" unsupported packed array expression" );
125- signalPassFailure ();
101+ const auto indexValues = createIndexValuePairs<hw::ArrayGetOp>(
102+ builder, concatOp.getLoc (), ty, array);
103+ for (const auto &[_, value] : indexValues) {
104+ values.push_back (value);
105+ }
126106 }
107+ if (!processUsers (op, concatOp.getResult (), values))
108+ return false ;
127109
128- // Remove the original op.
110+ // Remove original op.
129111 return true ;
130112 })
131113 .Case <hw::ArrayCreateOp>([&](hw::ArrayCreateOp createOp) {
@@ -147,15 +129,12 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
147129 // Generate case value element lookups.
148130 auto ty = hw::type_cast<hw::ArrayType>(getOp.getInput ().getType ());
149131 OpBuilder builder (getOp);
132+ auto loc = op.getLoc ();
133+ const auto indexValues = createIndexValuePairs<hw::ArrayGetOp>(
134+ builder, loc, ty, getOp.getInput ());
150135 SmallVector<Value> caseValues;
151- for (size_t i = 0 , e = ty.getNumElements (); i < e; i++) {
152- auto loc = op.getLoc ();
153- auto index = builder.createOrFold <hw::ConstantOp>(
154- loc, APInt (llvm::Log2_64_Ceil (e), i));
155- auto element =
156- hw::ArrayGetOp::create (builder, loc, getOp.getInput (), index);
157- caseValues.push_back (element);
158- }
136+ for (const auto &[_, value] : indexValues)
137+ caseValues.push_back (value);
159138
160139 // Transform array index op into casez statement.
161140 auto theWire = lowerLookupToCasez (op, getOp.getInput (), index,
@@ -183,13 +162,11 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
183162 // Generate case value element lookups.
184163 auto ty = hw::type_cast<hw::ArrayType>(inout.getElementType ());
185164 OpBuilder builder (&op);
165+ auto loc = op.getLoc ();
166+ const auto indexValues = createIndexValuePairs<sv::ArrayIndexInOutOp>(
167+ builder, loc, ty, indexOp.getInput ());
186168 SmallVector<Value> caseValues;
187- for (size_t i = 0 , e = ty.getNumElements (); i < e; i++) {
188- auto loc = op.getLoc ();
189- auto index = builder.createOrFold <hw::ConstantOp>(
190- loc, APInt (llvm::Log2_64_Ceil (e), i));
191- auto element = sv::ArrayIndexInOutOp::create (
192- builder, loc, indexOp.getInput (), index);
169+ for (const auto &[_, element] : indexValues) {
193170 auto readElement = sv::ReadInOutOp::create (builder, loc, element);
194171 caseValues.push_back (readElement);
195172 }
@@ -211,14 +188,12 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
211188 return false ;
212189
213190 OpBuilder builder (assignOp);
214- for ( size_t i = 0 , e = ty. getNumElements (); i < e; i++) {
215- auto loc = op. getLoc ();
216- auto index = builder. createOrFold <hw::ConstantOp>(
217- loc, APInt ( llvm::Log2_64_Ceil (e), i));
191+ auto loc = op. getLoc ();
192+ const auto indexValues = createIndexValuePairs<hw::ArrayGetOp>(
193+ builder, loc, ty, assignOp. getSrc ());
194+ for ( const auto &[index, srcElement] : indexValues) {
218195 auto dstElement = sv::ArrayIndexInOutOp::create (
219196 builder, loc, assignOp.getDest (), index);
220- auto srcElement =
221- hw::ArrayGetOp::create (builder, loc, assignOp.getSrc (), index);
222197 sv::PAssignOp::create (builder, loc, dstElement, srcElement);
223198 }
224199
@@ -251,9 +226,53 @@ bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
251226 // Remove original reg.
252227 return true ;
253228 })
229+ .Case <comb::MuxOp>([&](comb::MuxOp muxOp) {
230+ // Transform array mux into individual element muxes.
231+ auto ty = hw::type_dyn_cast<hw::ArrayType>(muxOp.getType ());
232+ if (!ty)
233+ return false ;
234+
235+ OpBuilder builder (muxOp);
236+
237+ auto trueValues = createIndexValuePairs<hw::ArrayGetOp>(
238+ builder, muxOp.getLoc (), ty, muxOp.getTrueValue ());
239+ auto falseValues = createIndexValuePairs<hw::ArrayGetOp>(
240+ builder, muxOp.getLoc (), ty, muxOp.getFalseValue ());
241+
242+ SmallVector<Value> muxedValues;
243+
244+ for (size_t i = 0 , e = trueValues.size (); i < e; i++) {
245+ const auto &[trueIndex, trueValue] = trueValues[i];
246+ const auto &[falseIndex, falseValue] = falseValues[i];
247+ muxedValues.push_back (
248+ comb::MuxOp::create (builder, muxOp.getLoc (), muxOp.getCond (),
249+ trueValue, falseValue, muxOp.getTwoState ()));
250+ }
251+
252+ if (!processUsers (op, muxOp.getResult (), muxedValues))
253+ return false ;
254+
255+ // Remove original mux.
256+ return true ;
257+ })
254258 .Default ([&](auto op) { return false ; });
255259}
256260
261+ template <typename ElementType>
262+ SmallVector<std::pair<Value, Value>>
263+ HWLegalizeModulesPass::createIndexValuePairs (OpBuilder &builder,
264+ LocationAttr loc, hw::ArrayType ty,
265+ Value array) {
266+ SmallVector<std::pair<Value, Value>> result;
267+ for (size_t i = 0 , e = ty.getNumElements (); i < e; i++) {
268+ auto index = builder.createOrFold <hw::ConstantOp>(
269+ loc, APInt (llvm::Log2_64_Ceil (e), i));
270+ auto element = ElementType::create (builder, loc, array, index);
271+ result.emplace_back (index, element);
272+ }
273+ return result;
274+ }
275+
257276Value HWLegalizeModulesPass::lowerLookupToCasez (Operation &op, Value input,
258277 Value index,
259278 mlir::Type elementType,
0 commit comments