@@ -48,6 +48,7 @@ limitations under the License.
4848#include " mlir/include/mlir/IR/Attributes.h"
4949#include " mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
5050#include " mlir/include/mlir/IR/OpDefinition.h"
51+ #include " mlir/include/mlir/IR/Visitors.h"
5152#include " jaxlib/mosaic/dialect/tpu/layout.h"
5253#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
5354#include " xla/layout.h"
@@ -122,10 +123,7 @@ class VectorLayoutInferer {
122123
123124 LogicalResult inferBlock (
124125 Block &block,
125- const std::function<LogicalResult(Operation *)> &match_terminator,
126- // TODO(jevinjiang): Propagate this flag deeper because it won't work when
127- // there is an op with blocks inside this block.
128- bool override_layout = false) {
126+ const std::function<LogicalResult(Operation *)> &match_terminator) {
129127 for (Operation &any_op : block.without_terminator ()) {
130128 VLOG (kLayoutLog ) << Print (&any_op);
131129 if (any_op.hasAttr (" in_layout" ) || any_op.hasAttr (" out_layout" )) {
@@ -134,8 +132,6 @@ class VectorLayoutInferer {
134132 any_op.hasAttr (" in_layout" ) && any_op.hasAttr (" out_layout" ),
135133 " expect layout attributes in tpu::AssumeLayoutOp" );
136134 continue ;
137- } else if (override_layout) {
138- // Intend to override the layouts attribute.
139135 } else {
140136 any_op.emitOpError (" layout attributes already attached" );
141137 return failure ();
@@ -508,35 +504,15 @@ class VectorLayoutInferer {
508504 op->getNumOperands () == 3 + op.getNumResults (),
509505 " expected num_operands is equal to 3 + num_results in scf.for" );
510506
511- SmallVector<Layout, 4 > in_layouts = getLayoutFromOperands (op);
512- // Drop the first 3 layouts for lower bound, upper bound and step.
513- ArrayRef<Layout> arg_layouts = ArrayRef<Layout>(in_layouts).drop_front (3 );
514- SmallVector<tpu::AssumeLayoutOp, 4 > assume_layout_ops;
515- assume_layout_ops.reserve (arg_layouts.size ());
516- // Use tpu.assume_layout to annotate every block argument with the layout of
517- // the corresponding operand in forOp and replace all uses of the block
518- // argument with the result of tpu.assume_layout.
519- ImplicitLocOpBuilder builder =
520- ImplicitLocOpBuilder::atBlockBegin (op.getLoc (), op.getBody ());
521-
522- // Drop the induction_variable and layouts of bounds+step (respectively).
523- for (auto [iter_arg, layout] : llvm::zip_equal (
524- op.getBody ()->getArguments ().drop_front (1 ), arg_layouts)) {
525- if (!dyn_cast<VectorType>(iter_arg.getType ())) {
526- assume_layout_ops.push_back (nullptr );
527- continue ;
528- }
529- auto assume_layout_op =
530- builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
531- setLayout (assume_layout_op, layout, layout);
532- assume_layout_ops.push_back (assume_layout_op);
533- iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
534- return operand.getOwner () != assume_layout_op;
535- });
536- }
537-
538- if (inferBlock (*op.getBody (), match_yield).failed ()) {
539- return failure ();
507+ auto in_layouts = getLayoutFromOperands (op);
508+ // Drop the input layouts for lower bound, upper bound. But keep the layout
509+ // for step because it matches with induction variable in arguments.
510+ auto arg_layouts = ArrayRef<Layout>(in_layouts).drop_front (2 );
511+ if (assumeLayoutsForBlockArgs (*op.getBody (), arg_layouts).failed () ||
512+ inferBlock (*op.getBody (), match_yield).failed ()) {
513+ return op.emitOpError (
514+ " failed to infer layout with initial layouts for body in "
515+ " scf.for op" );
540516 }
541517 auto yield_op = op.getBody ()->getTerminator ();
542518 auto yield_in_layouts = getLayoutFromOperands (yield_op);
@@ -546,7 +522,8 @@ class VectorLayoutInferer {
546522 int out_idx = 0 ;
547523 bool require_reinfer = false ;
548524 for (auto [in_layout, yield_layout, result] :
549- llvm::zip_equal (ArrayRef<Layout>(in_layouts).drop_front (3 ),
525+ llvm::zip_equal (arg_layouts.drop_front (
526+ 1 ), // Drop the layout for induction variable.
550527 yield_in_layouts, op.getResults ())) {
551528 if (auto vty = dyn_cast<VectorType>(result.getType ())) {
552529 if (!in_layout.has_value ()) {
@@ -586,24 +563,25 @@ class VectorLayoutInferer {
586563 ++out_idx;
587564 }
588565 if (require_reinfer) {
566+ // Force same layouts in input layout but skip the first 3 layouts for
567+ // lower bound, upper bound and step.
568+ std::copy (out_layouts.begin (), out_layouts.end (), in_layouts.begin () + 3 );
569+
589570 // Terminator in the loop will carry layouts to the next loop but
590571 // the loop's block args' layouts are determined by the initial inputs. We
591572 // need to force the same layouts for all in order to make layouts be
592573 // consistent across all branches. To ensure that, we need to reprocess
593574 // layout inference for the entire body with the final consolidated
594575 // layout.
595- for ( int64_t i = 0 ; i < out_layouts. size (); ++i) {
596- if (assume_layout_ops[i]) {
597- setLayout (assume_layout_ops[i], out_layouts[i], out_layouts[i]);
598- }
599- }
600- if ( inferBlock (* op.getBody (), match_yield, /* override_layout= */ true )
601- . failed ()) {
602- return op. emitOpError ( " failed to infer layout for scf.for op" );
576+ clearBlockLayouts (*op. getBody ());
577+ if (assumeLayoutsForBlockArgs (*op. getBody (),
578+ ArrayRef<Layout>(in_layouts). drop_front ( 2 ))
579+ . failed () ||
580+ inferBlock (*op. getBody (), match_yield). failed ()) {
581+ return op.emitOpError (
582+ " failed to infer layout with compatible layouts for body in "
583+ " scf.for op" );
603584 }
604- std::copy (out_layouts.begin (), out_layouts.end (),
605- in_layouts.begin () + 3 ); // Skip first 3 layouts for lower
606- // bound, upper bound and step.
607585 }
608586 setInLayout (yield_op, out_layouts);
609587 setLayout (op, in_layouts, out_layouts);
@@ -622,53 +600,19 @@ class VectorLayoutInferer {
622600 TPU_CHECK_OP (op.getNumRegions () == 2 , " expected two blocks for scf.while" );
623601
624602 SmallVector<Layout, 4 > in_layouts = getLayoutFromOperands (op);
625- SmallVector<tpu::AssumeLayoutOp, 4 > before_assume_layout_ops;
626- before_assume_layout_ops.reserve (in_layouts.size ());
627- SmallVector<tpu::AssumeLayoutOp, 4 > after_assume_layout_ops;
628- after_assume_layout_ops.reserve (in_layouts.size ());
629603
630- // Use tpu.assume_layout to annotate every block argument with the layout of
631- // the corresponding operand in WhileOp and replace all uses of the block
632- // argument with the result of tpu.assume_layout.
633- ImplicitLocOpBuilder builder =
634- ImplicitLocOpBuilder::atBlockBegin (op.getLoc (), op.getBeforeBody ());
635- for (auto [iter_arg, layout] :
636- llvm::zip_equal (op.getBeforeBody ()->getArguments (), in_layouts)) {
637- if (!dyn_cast<VectorType>(iter_arg.getType ())) {
638- before_assume_layout_ops.push_back (nullptr );
639- continue ;
640- }
641- auto assume_layout_op =
642- builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
643- setLayout (assume_layout_op, layout, layout);
644- before_assume_layout_ops.push_back (assume_layout_op);
645- iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
646- return operand.getOwner () != assume_layout_op;
647- });
648- }
649- if (inferBlock (*op.getBeforeBody (), match_condition).failed ()) {
650- return failure ();
604+ if (assumeLayoutsForBlockArgs (*op.getBeforeBody (), in_layouts).failed () ||
605+ inferBlock (*op.getBeforeBody (), match_condition).failed ()) {
606+ return op.emitOpError (
607+ " failed to infer layout with initial layouts for before body in "
608+ " scf.while op" );
651609 }
652610
653- builder =
654- ImplicitLocOpBuilder::atBlockBegin (op.getLoc (), op.getAfterBody ());
655- for (auto [iter_arg, layout] :
656- llvm::zip_equal (op.getAfterBody ()->getArguments (), in_layouts)) {
657- if (!dyn_cast<VectorType>(iter_arg.getType ())) {
658- after_assume_layout_ops.push_back (nullptr );
659- continue ;
660- }
661- auto assume_layout_op =
662- builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
663- setLayout (assume_layout_op, layout, layout);
664- after_assume_layout_ops.push_back (assume_layout_op);
665- iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
666- return operand.getOwner () != assume_layout_op;
667- });
668- }
669-
670- if (inferBlock (*op.getAfterBody (), match_yield).failed ()) {
671- return failure ();
611+ if (assumeLayoutsForBlockArgs (*op.getAfterBody (), in_layouts).failed () ||
612+ inferBlock (*op.getAfterBody (), match_yield).failed ()) {
613+ return op.emitOpError (
614+ " failed to infer layout with initial layouts for after body in "
615+ " scf.while op" );
672616 }
673617
674618 auto *cond_op = op.getBeforeBody ()->getTerminator ();
@@ -738,27 +682,26 @@ class VectorLayoutInferer {
738682 ++out_idx;
739683 }
740684 if (require_reinfer) {
685+ clearBlockLayouts (*op.getBeforeBody ());
686+ clearBlockLayouts (*op.getAfterBody ());
741687 // Terminator in the loop will carry layouts to the next loop but
742688 // the loop's block args' layouts are determined by the initial inputs. We
743689 // need to force the same layouts for all in order to make layouts be
744690 // consistent across all branches. To ensure that, we need to reprocess
745691 // layout inference for the entire body with the final consolidated
746692 // layout.
747- for (int64_t i = 0 ; i < out_layouts.size (); ++i) {
748- if (before_assume_layout_ops[i]) {
749- setLayout (before_assume_layout_ops[i], out_layouts[i],
750- out_layouts[i]);
751- }
752- if (after_assume_layout_ops[i]) {
753- setLayout (after_assume_layout_ops[i], out_layouts[i], out_layouts[i]);
754- }
755- }
756- if (inferBlock (*op.getBeforeBody (), match_condition,
757- /* override_layout=*/ true )
693+ if (assumeLayoutsForBlockArgs (*op.getBeforeBody (), out_layouts)
758694 .failed () ||
759- inferBlock (*op.getAfterBody (), match_yield, /* override_layout=*/ true )
760- .failed ()) {
761- return op.emitOpError (" failed to infer layout for scf.while op" );
695+ inferBlock (*op.getBeforeBody (), match_condition).failed ()) {
696+ return op.emitOpError (
697+ " failed to infer layout with compatible layouts for before body in "
698+ " scf.while op" );
699+ }
700+ if (assumeLayoutsForBlockArgs (*op.getAfterBody (), out_layouts).failed () ||
701+ inferBlock (*op.getAfterBody (), match_yield).failed ()) {
702+ return op.emitOpError (
703+ " failed to infer layout with compatible layouts for after body in "
704+ " scf.while op" );
762705 }
763706 }
764707 std::copy (out_layouts.begin (), out_layouts.end (),
@@ -1854,6 +1797,53 @@ class VectorLayoutInferer {
18541797 return true ;
18551798 }
18561799
1800+ LogicalResult assumeLayoutsForBlockArgs (Block &block,
1801+ ArrayRef<Layout> layouts) {
1802+ auto op = block.getParentOp ();
1803+ if (layouts.size () != block.getNumArguments ()) {
1804+ return op->emitOpError (
1805+ " Block arguments must have the same number of layouts" );
1806+ }
1807+ // Use tpu.assume_layout to annotate every block argument with the layout of
1808+ // the corresponding operand and replace all uses of the block argument with
1809+ // the result of tpu.assume_layout.
1810+ ImplicitLocOpBuilder builder =
1811+ ImplicitLocOpBuilder::atBlockBegin (op->getLoc (), &block);
1812+ for (auto [iter_arg, layout] :
1813+ llvm::zip_equal (block.getArguments (), layouts)) {
1814+ if (!dyn_cast<VectorType>(iter_arg.getType ())) {
1815+ continue ;
1816+ }
1817+ if (llvm::any_of (iter_arg.getUsers (), [](Operation *user) {
1818+ return isa<tpu::AssumeLayoutOp>(user);
1819+ })) {
1820+ return op->emitOpError (" Expected no assume layout for block arguments" );
1821+ }
1822+ auto assume_layout_op =
1823+ builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
1824+ setLayout (assume_layout_op, layout, layout);
1825+ iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
1826+ return operand.getOwner () != assume_layout_op;
1827+ });
1828+ }
1829+ return success ();
1830+ }
1831+
1832+ void clearBlockLayouts (Block &block) {
1833+ block.walk ([&](Operation *op) {
1834+ // We need to remove assume_layout ops in each block. Otherwise, we will
1835+ // create extra assume_layout ops for nested blocks.
1836+ if (auto assume_op = dyn_cast<tpu::AssumeLayoutOp>(op)) {
1837+ assume_op.getResult ().replaceAllUsesWith (assume_op.getInput ());
1838+ assume_op->erase ();
1839+ return WalkResult::advance ();
1840+ }
1841+ op->removeAttr (" in_layout" );
1842+ op->removeAttr (" out_layout" );
1843+ return WalkResult::advance ();
1844+ });
1845+ }
1846+
18571847 void setInLayout (Operation *op, ArrayRef<Layout> in) {
18581848 CHECK_EQ (in.size (), op->getNumOperands ()) << Print (op);
18591849 SmallVector<Attribute, 4 > in_attrs;
0 commit comments