@@ -328,6 +328,22 @@ LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() {
328328 return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget ().getDefiningOp ());
329329};
330330
331+ // ===----------------------------------------------------------------------===//
332+ // AMDAIE_LogicalObjectFifoAccessOp
333+ // ===----------------------------------------------------------------------===//
334+
335+ void LogicalObjectFifoAccessOp::build (OpBuilder &b,
336+ mlir::OperationState &result, Value input,
337+ MemoryAccess accessType) {
338+ auto type = llvm::cast<LogicalObjectFifoType>(input.getType ());
339+ build (b, result, type.getElementType (), input, accessType);
340+ }
341+
342+ LogicalObjectFifoFromMemrefOp
343+ LogicalObjectFifoAccessOp::getLogicalObjectFifo () {
344+ return dyn_cast<LogicalObjectFifoFromMemrefOp>(getInput ().getDefiningOp ());
345+ };
346+
331347// ===----------------------------------------------------------------------===//
332348// AMDAIE_LogicalObjectFifoAcquire
333349// ===----------------------------------------------------------------------===//
@@ -341,6 +357,26 @@ void LogicalObjectFifoAcquire::build(OpBuilder &b, mlir::OperationState &result,
341357// AMDAIE_LogicalObjectFifoFromMemrefOp
342358// ===----------------------------------------------------------------------===//
343359
360+ // / Build with an array of static tile locations.
361+ void LogicalObjectFifoFromMemrefOp::build (
362+ OpBuilder &b, mlir::OperationState &result, Value memref,
363+ ArrayRef<std::pair<int64_t , int64_t >> tileLocations) {
364+ SmallVector<Value> tiles;
365+ tiles.reserve (tileLocations.size ());
366+ for (auto [column, row] : tileLocations) {
367+ auto colIndex = b.create <arith::ConstantIndexOp>(b.getUnknownLoc (), column);
368+ auto rowIndex = b.create <arith::ConstantIndexOp>(b.getUnknownLoc (), row);
369+ auto tileOp =
370+ b.create <AMDAIE::TileOp>(b.getUnknownLoc (), colIndex, rowIndex);
371+ tiles.push_back (tileOp.getResult ());
372+ }
373+ // For deterministic order.
374+ llvm::sort (tiles.begin (), tiles.end (),
375+ TileOp::tileValueColumnAndRowComparator);
376+ auto type = LogicalObjectFifoType::get (cast<MemRefType>(memref.getType ()));
377+ build (b, result, type, memref, tiles);
378+ }
379+
344380LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize (
345381 LogicalObjectFifoFromMemrefOp logicalObjectFifo,
346382 PatternRewriter &rewriter) {
@@ -349,23 +385,19 @@ LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize(
349385 return success ();
350386 }
351387
352- auto comparator = [](Value a, Value b) -> bool {
353- TileOp tileA = dyn_cast<TileOp>(a.getDefiningOp ());
354- TileOp tileB = dyn_cast<TileOp>(b.getDefiningOp ());
355- int64_t colA = getConstantIntValue (tileA.getCol ()).value ();
356- int64_t rowA = getConstantIntValue (tileA.getRow ()).value ();
357- int64_t colB = getConstantIntValue (tileB.getCol ()).value ();
358- int64_t rowB = getConstantIntValue (tileB.getRow ()).value ();
359- if (colA == colB) return rowA < rowB;
360- return colA < colB;
361- };
362388 SmallVector<Value> tiles = logicalObjectFifo.getTiles ();
363- if (llvm::is_sorted (tiles, comparator)) {
389+ if (llvm::is_sorted (tiles, TileOp::tileValueColumnAndRowComparator)) {
390+ // Still erase duplicates.
391+ tiles.erase (std::unique (tiles.begin (), tiles.end ()), tiles.end ());
364392 return success ();
365393 }
366394
367- // If tiles are not sorted, sort them and replace the logical objectfifo
368- llvm::sort (tiles.begin (), tiles.end (), comparator);
395+ // If tiles are not sorted, sort them, erase duplicates and replace the
396+ // logical objectfifo.
397+ llvm::sort (tiles.begin (), tiles.end (),
398+ TileOp::tileValueColumnAndRowComparator);
399+ tiles.erase (std::unique (tiles.begin (), tiles.end ()), tiles.end ());
400+
369401 rewriter.replaceOpWithNewOp <AMDAIE::LogicalObjectFifoFromMemrefOp>(
370402 logicalObjectFifo,
371403 llvm::cast<LogicalObjectFifoType>(
@@ -532,6 +564,23 @@ bool TileOp::hasStaticLocation() {
532564 return getConstantIntValue (getCol ()) && getConstantIntValue (getRow ());
533565}
534566
567+ bool TileOp::tileColumnComparator (AMDAIE::TileOp &a, AMDAIE::TileOp &b) {
568+ int64_t colA = getConstantIntValue (a.getCol ()).value ();
569+ int64_t colB = getConstantIntValue (b.getCol ()).value ();
570+ return colA < colB;
571+ }
572+
573+ bool TileOp::tileValueColumnAndRowComparator (Value a, Value b) {
574+ TileOp tileA = dyn_cast<AMDAIE::TileOp>(a.getDefiningOp ());
575+ TileOp tileB = dyn_cast<AMDAIE::TileOp>(b.getDefiningOp ());
576+ int64_t colA = getConstantIntValue (tileA.getCol ()).value ();
577+ int64_t rowA = getConstantIntValue (tileA.getRow ()).value ();
578+ int64_t colB = getConstantIntValue (tileB.getCol ()).value ();
579+ int64_t rowB = getConstantIntValue (tileB.getRow ()).value ();
580+ if (colA == colB) return rowA < rowB;
581+ return colA < colB;
582+ };
583+
535584// ===----------------------------------------------------------------------===//
536585// AMDAIE_WorkgroupOp
537586// ===----------------------------------------------------------------------===//
0 commit comments