@@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556 return SYCLGenError ();
557557 OS () << " , " ;
558558 switch (T->getKind ()) {
559+ case InlineAsmVectorType::v1:
560+ OS () << 1 ;
561+ break ;
559562 case InlineAsmVectorType::v2:
560563 OS () << 2 ;
561564 break ;
@@ -1309,53 +1312,118 @@ class SYCLGen : public SYCLGenBase {
13091312 if (Inst->getNumInputOperands () != 3 )
13101313 return SYCLGenError ();
13111314
1312- if (!Inst->hasAttr (InstAttr::m16n8k16))
1315+ const InlineAsmVectorExpr *DMatVE =
1316+ dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand ());
1317+ if (!DMatVE)
13131318 return SYCLGenError ();
13141319
13151320 // Only row Layout is supported for of A matrix and
13161321 // only col Layout is supported for of B matrix
1317- if (Inst->getAttr (3 ) != InstAttr::row ||
1318- Inst->getAttr (4 ) != InstAttr::col) {
1322+ if (Inst->getAttr (3 ) != InstAttr::row || Inst->getAttr (4 ) != InstAttr::col)
13191323 return SYCLGenError ();
1320- }
13211324
13221325 // Only f16 type is supported for A and B matrix data
1326+ const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (0 ));
13231327 const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (1 ));
13241328 const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (2 ));
1329+ const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (3 ));
13251330
1326- std::string TypeStr;
1327- if (!AType || !BType ||
1328- (AType->getKind () != InlineAsmBuiltinType::f16 ||
1329- BType->getKind () != InlineAsmBuiltinType::f16 )) {
1331+ if (!(AType && BType && CType && DType))
13301332 return SYCLGenError ();
1331- } else {
1332- if (tryEmitType (TypeStr, AType))
1333- return SYCLGenError ();
1334- }
13351333
1336- const InlineAsmVectorExpr *VE =
1337- dyn_cast<InlineAsmVectorExpr>(Inst-> getOutputOperand ());
1338- if (VE && VE-> getNumElements () != 4 ) {
1334+ // Data types of matrix elements for A&B and C&D matrices should be same
1335+ if ((AType-> getKind () != BType-> getKind ()) ||
1336+ (CType-> getKind () != DType-> getKind ()))
13391337 return SYCLGenError ();
1338+
1339+ // Check the validity of AB & CD types
1340+ std::string ABType, CDType;
1341+ if (tryEmitType (ABType, AType))
1342+ return SYCLGenError ();
1343+
1344+ if (tryEmitType (CDType, CType))
1345+ return SYCLGenError ();
1346+
1347+ // Register sizes for vector elements of A, B, C & D matrices
1348+ unsigned NumVecElements[4 ] = {0 };
1349+
1350+ // Data type used to multiply A & B matrices
1351+ std::string MulType;
1352+ if (Inst->hasAttr (InstAttr::m16n8k16)) {
1353+ // Only f16 type is supported for A and B matrix data for m16n8k16
1354+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1355+ // If A matrix type is f16, then C&D matrix types can only be f16
1356+ if (CType->getKind () == AType->getKind ()) {
1357+ NumVecElements[0 ] = 2 ; // A
1358+ NumVecElements[1 ] = 4 ; // B
1359+ NumVecElements[2 ] = 4 ; // C
1360+ NumVecElements[3 ] = 4 ; // D
1361+ } else
1362+ return SYCLGenError ();
1363+ } else
1364+ return SYCLGenError ();
1365+ } else if (Inst->hasAttr (InstAttr::m8n8k4)) {
1366+ // f16 & f64 types are supported for A and B matrix data for m8n8k4
1367+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1368+ // If A matrix type is f16, then C&D matrix types can only be f16/f32
1369+ if (CType->getKind () == AType->getKind ()) {
1370+ NumVecElements[0 ] = 2 ; // A
1371+ NumVecElements[1 ] = 2 ; // B
1372+ NumVecElements[2 ] = 4 ; // C
1373+ NumVecElements[3 ] = 4 ; // D
1374+ } else if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1375+ NumVecElements[0 ] = 2 ; // A
1376+ NumVecElements[1 ] = 2 ; // B
1377+ NumVecElements[2 ] = 8 ; // C
1378+ NumVecElements[3 ] = 8 ; // D
1379+ } else
1380+ return SYCLGenError ();
1381+ } else if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
1382+ // If A matrix type is f64, then C&D matrix types can only be f64
1383+ if (CType->getKind () == AType->getKind ()) {
1384+ NumVecElements[0 ] = 1 ; // A
1385+ NumVecElements[1 ] = 1 ; // B
1386+ NumVecElements[2 ] = 2 ; // C
1387+ NumVecElements[3 ] = 2 ; // D
1388+ } else
1389+ return SYCLGenError ();
1390+ } else
1391+ return SYCLGenError ();
1392+ } else
1393+ return SYCLGenError ();
1394+
1395+ // Check the register sizes for vector elements of A, B, C & D matrices
1396+ for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1397+ InputOp++) {
1398+ if (auto VE =
1399+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1400+ if (VE->getNumElements () != NumVecElements[InputOp])
1401+ return SYCLGenError ();
1402+ } else
1403+ return SYCLGenError ();
13401404 }
1405+ if (DMatVE->getNumElements () != NumVecElements[3 ])
1406+ return SYCLGenError ();
13411407
1408+ MulType = ABType;
13421409 OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1343- OS () << " <" << TypeStr << " >(" ;
1410+ OS () << " <" << MulType << " >(" ;
13441411
13451412 // Add D matrix address values to store the MAD result
1346- for (unsigned Inst = 0 ; Inst != VE ->getNumElements (); ++Inst) {
1347- if (isa<InlineAsmDiscardExpr>(VE ->getElement (Inst)))
1413+ for (unsigned Inst = 0 ; Inst != DMatVE ->getNumElements (); ++Inst) {
1414+ if (isa<InlineAsmDiscardExpr>(DMatVE ->getElement (Inst)))
13481415 continue ;
13491416 OS () << " &" ;
1350- if (emitStmt (VE ->getElement (Inst)))
1417+ if (emitStmt (DMatVE ->getElement (Inst)))
13511418 return SYCLGenError ();
13521419 OS () << " , " ;
13531420 }
13541421
13551422 // Add A, B & C matrix values to compute MAD
13561423 for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
13571424 InputOp++) {
1358- if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1425+ if (auto VE =
1426+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
13591427 for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
13601428 if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
13611429 continue ;
@@ -2607,11 +2675,10 @@ class SYCLGen : public SYCLGenBase {
26072675 Op = std::move (NewOp);
26082676 }
26092677
2610- bool HasHalfOrBfloat16 =
2611- SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2612- DesType->getKind () == InlineAsmBuiltinType::f16 ||
2613- SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2614- DesType->getKind () == InlineAsmBuiltinType::bf16 ;
2678+ bool HasHalfOrBfloat16 = SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2679+ DesType->getKind () == InlineAsmBuiltinType::f16 ||
2680+ SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2681+ DesType->getKind () == InlineAsmBuiltinType::bf16 ;
26152682 if (DpctGlobalInfo::useIntelDeviceMath () && HasHalfOrBfloat16) {
26162683 insertHeader (HeaderType::HT_SYCL_Math);
26172684 if (SrcNeedBitCast)
0 commit comments