@@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
514514bool SYCLGenBase::emitBuiltinType (const InlineAsmBuiltinType *T) {
515515 switch (T->getKind ()) {
516516 // clang-format off
517+ case InlineAsmBuiltinType::b1: OS () << " uint8_t" ; break ;
517518 case InlineAsmBuiltinType::b8: OS () << " uint8_t" ; break ;
518519 case InlineAsmBuiltinType::b16: OS () << " uint16_t" ; break ;
519520 case InlineAsmBuiltinType::b32: OS () << " uint32_t" ; break ;
520521 case InlineAsmBuiltinType::b64: OS () << " uint64_t" ; break ;
522+ case InlineAsmBuiltinType::u4: OS () << " uint8_t" ; break ;
521523 case InlineAsmBuiltinType::u8 : OS () << " uint8_t" ; break ;
522524 case InlineAsmBuiltinType::u16 : OS () << " uint16_t" ; break ;
523525 case InlineAsmBuiltinType::u32 : OS () << " uint32_t" ; break ;
524526 case InlineAsmBuiltinType::u64 : OS () << " uint64_t" ; break ;
527+ case InlineAsmBuiltinType::s4: OS () << " int8_t" ; break ;
525528 case InlineAsmBuiltinType::s8: OS () << " int8_t" ; break ;
526529 case InlineAsmBuiltinType::s16: OS () << " int16_t" ; break ;
527530 case InlineAsmBuiltinType::s32: OS () << " int32_t" ; break ;
@@ -1347,44 +1350,276 @@ class SYCLGen : public SYCLGenBase {
13471350 // Register sizes for vector elements of A, B, C & D matrices
13481351 unsigned NumVecElements[4 ] = {0 };
13491352
1353+ // Sizes of A & B matrices
1354+ std::string M, N, K;
1355+
1356+ // Operator for m8n8k128/m16n8k128/m16n8k256
1357+ std::string MatrixOp;
1358+
13501359 // Data type used to multiply A & B matrices
13511360 std::string MulType;
1352- if (Inst->hasAttr (InstAttr::m16n8k16)) {
1353- // Only f16 type is supported for A and B matrix data for m16n8k16
1361+ if (Inst->hasAttr (InstAttr::m8n8k4)) {
1362+ M = " 8" ;
1363+ N = " 8" ;
1364+ K = " 4" ;
1365+ // f16 & f64 types are supported for A and B matrices of m8n8k4
13541366 if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1355- // If A matrix type is f16, then C&D matrix types can only be f16
1367+ // If A matrix type is f16, then C&D matrix types can only be f16/f32
13561368 if (CType->getKind () == AType->getKind ()) {
13571369 NumVecElements[0 ] = 2 ; // A
1358- NumVecElements[1 ] = 4 ; // B
1370+ NumVecElements[1 ] = 2 ; // B
13591371 NumVecElements[2 ] = 4 ; // C
13601372 NumVecElements[3 ] = 4 ; // D
1373+ } else if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1374+ NumVecElements[0 ] = 2 ; // A
1375+ NumVecElements[1 ] = 2 ; // B
1376+ NumVecElements[2 ] = 8 ; // C
1377+ NumVecElements[3 ] = 8 ; // D
1378+ } else
1379+ return SYCLGenError ();
1380+ } else if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
1381+ // If A matrix type is f64, then C&D matrix types can only be f64
1382+ if (CType->getKind () == AType->getKind ()) {
1383+ NumVecElements[0 ] = 1 ; // A
1384+ NumVecElements[1 ] = 1 ; // B
1385+ NumVecElements[2 ] = 2 ; // C
1386+ NumVecElements[3 ] = 2 ; // D
1387+ } else
1388+ return SYCLGenError ();
1389+ } else
1390+ return SYCLGenError ();
1391+ } else if (Inst->hasAttr (InstAttr::m8n8k16)) {
1392+ M = " 8" ;
1393+ N = " 8" ;
1394+ K = " 16" ;
1395+ // Only s8/u8 types are supported for A and B matrices of m8n8k16
1396+ if (AType->getKind () == InlineAsmBuiltinType::s8 ||
1397+ AType->getKind () == InlineAsmBuiltinType::u8 ) {
1398+ // If A matrix type is s8/u8, then C&D matrix types can only be s32
1399+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1400+ NumVecElements[0 ] = 1 ; // A
1401+ NumVecElements[1 ] = 1 ; // B
1402+ NumVecElements[2 ] = 2 ; // C
1403+ NumVecElements[3 ] = 2 ; // D
13611404 } else
13621405 return SYCLGenError ();
13631406 } else
13641407 return SYCLGenError ();
1365- } else if (Inst->hasAttr (InstAttr::m8n8k4)) {
1366- // f16 & f64 types are supported for A and B matrix data for m8n8k4
1408+ } else if (Inst->hasAttr (InstAttr::m8n8k32)) {
1409+ M = " 8" ;
1410+ N = " 8" ;
1411+ K = " 32" ;
1412+ // Only s4/u4 types are supported for A and B matrices of m16n8k32
1413+ if (AType->getKind () == InlineAsmBuiltinType::s4 ||
1414+ AType->getKind () == InlineAsmBuiltinType::u4) {
1415+ // If A matrix type is s4/u4, then C&D matrix types can only be s32
1416+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1417+ NumVecElements[0 ] = 1 ; // A
1418+ NumVecElements[1 ] = 1 ; // B
1419+ NumVecElements[2 ] = 2 ; // C
1420+ NumVecElements[3 ] = 2 ; // D
1421+ } else
1422+ return SYCLGenError ();
1423+ } else
1424+ return SYCLGenError ();
1425+ } else if (Inst->hasAttr (InstAttr::m8n8k128)) {
1426+ M = " 8" ;
1427+ N = " 8" ;
1428+ K = " 128" ;
1429+ // Only b1 type is supported for A and B matrices of m16n8k128
1430+ if (AType->getKind () == InlineAsmBuiltinType::b1) {
1431+ // If A matrix type is b1, then C&D matrix types can only be s32
1432+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1433+ NumVecElements[0 ] = 1 ; // A
1434+ NumVecElements[1 ] = 1 ; // B
1435+ NumVecElements[2 ] = 2 ; // C
1436+ NumVecElements[3 ] = 2 ; // D
1437+
1438+ // Only and/xor bitwise operations are supported for m8n8k128
1439+ if (Inst->hasAttr (InstAttr::op_and))
1440+ MatrixOp = " and" ;
1441+ else if (Inst->hasAttr (InstAttr::op_xor))
1442+ MatrixOp = " xor" ;
1443+ else
1444+ return SYCLGenError ();
1445+ } else
1446+ return SYCLGenError ();
1447+ } else
1448+ return SYCLGenError ();
1449+ } else if (Inst->hasAttr (InstAttr::m16n8k4)) {
1450+ M = " 16" ;
1451+ N = " 8" ;
1452+ K = " 4" ;
1453+ // Only f64 type is supported for A and B matrices of m16n8k4
1454+ if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
1455+ // If A matrix type is f64, then C&D matrix types can only be f64
1456+ if (CType->getKind () == InlineAsmBuiltinType::f64 ) {
1457+ NumVecElements[0 ] = 2 ; // A
1458+ NumVecElements[1 ] = 1 ; // B
1459+ NumVecElements[2 ] = 4 ; // C
1460+ NumVecElements[3 ] = 4 ; // D
1461+ } else
1462+ return SYCLGenError ();
1463+ } else
1464+ return SYCLGenError ();
1465+ } else if (Inst->hasAttr (InstAttr::m16n8k8)) {
1466+ M = " 16" ;
1467+ N = " 8" ;
1468+ K = " 8" ;
1469+ // Only f16/f64 types are supported for A and B matrices of m16n8k8
13671470 if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
13681471 // If A matrix type is f16, then C&D matrix types can only be f16/f32
1369- if (CType->getKind () == AType-> getKind () ) {
1472+ if (CType->getKind () == InlineAsmBuiltinType:: f16 ) {
13701473 NumVecElements[0 ] = 2 ; // A
1474+ NumVecElements[1 ] = 1 ; // B
1475+ NumVecElements[2 ] = 2 ; // C
1476+ NumVecElements[3 ] = 2 ; // D
1477+ } else if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1478+ NumVecElements[0 ] = 2 ; // A
1479+ NumVecElements[1 ] = 1 ; // B
1480+ NumVecElements[2 ] = 4 ; // C
1481+ NumVecElements[3 ] = 4 ; // D
1482+ } else
1483+ return SYCLGenError ();
1484+ } else if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
1485+ // If A matrix type is f64, then C&D matrix types can only be f64
1486+ if (CType->getKind () == InlineAsmBuiltinType::f64 ) {
1487+ NumVecElements[0 ] = 4 ; // A
13711488 NumVecElements[1 ] = 2 ; // B
13721489 NumVecElements[2 ] = 4 ; // C
13731490 NumVecElements[3 ] = 4 ; // D
1491+ } else
1492+ return SYCLGenError ();
1493+ } else
1494+ return SYCLGenError ();
1495+ } else if (Inst->hasAttr (InstAttr::m16n8k16)) {
1496+ M = " 16" ;
1497+ N = " 8" ;
1498+ K = " 16" ;
1499+ // Only f16/f64/s8/u8 type is supported for A and B matrices of m16n8k16
1500+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1501+ // If A matrix type is f16, then C&D matrix types can only be f16/f32
1502+ if (CType->getKind () == AType->getKind ()) {
1503+ NumVecElements[0 ] = 4 ; // A
1504+ NumVecElements[1 ] = 2 ; // B
1505+ NumVecElements[2 ] = 2 ; // C
1506+ NumVecElements[3 ] = 2 ; // D
13741507 } else if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1375- NumVecElements[0 ] = 2 ; // A
1508+ NumVecElements[0 ] = 4 ; // A
13761509 NumVecElements[1 ] = 2 ; // B
1377- NumVecElements[2 ] = 8 ; // C
1378- NumVecElements[3 ] = 8 ; // D
1510+ NumVecElements[2 ] = 4 ; // C
1511+ NumVecElements[3 ] = 4 ; // D
13791512 } else
13801513 return SYCLGenError ();
13811514 } else if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
13821515 // If A matrix type is f64, then C&D matrix types can only be f64
13831516 if (CType->getKind () == AType->getKind ()) {
1384- NumVecElements[0 ] = 1 ; // A
1517+ NumVecElements[0 ] = 8 ; // A
1518+ NumVecElements[1 ] = 4 ; // B
1519+ NumVecElements[2 ] = 4 ; // C
1520+ NumVecElements[3 ] = 4 ; // D
1521+ } else
1522+ return SYCLGenError ();
1523+ } else if (AType->getKind () == InlineAsmBuiltinType::s8 ||
1524+ AType->getKind () == InlineAsmBuiltinType::u8 ) {
1525+ // If A matrix type is s8/u8, then C&D matrix types can only be s32
1526+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1527+ NumVecElements[0 ] = 2 ; // A
13851528 NumVecElements[1 ] = 1 ; // B
1386- NumVecElements[2 ] = 2 ; // C
1387- NumVecElements[3 ] = 2 ; // D
1529+ NumVecElements[2 ] = 4 ; // C
1530+ NumVecElements[3 ] = 4 ; // D
1531+ } else
1532+ return SYCLGenError ();
1533+ } else
1534+ return SYCLGenError ();
1535+ } else if (Inst->hasAttr (InstAttr::m16n8k32)) {
1536+ M = " 16" ;
1537+ N = " 8" ;
1538+ K = " 32" ;
1539+ // Only s4/s8/u4/u8 types are supported for A and B matrices of m16n8k32
1540+ if (AType->getKind () == InlineAsmBuiltinType::s4 ||
1541+ AType->getKind () == InlineAsmBuiltinType::u4) {
1542+ // If A matrix type is s4/u4, then C&D matrix types can only be s32
1543+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1544+ NumVecElements[0 ] = 2 ; // A
1545+ NumVecElements[1 ] = 1 ; // B
1546+ NumVecElements[2 ] = 4 ; // C
1547+ NumVecElements[3 ] = 4 ; // D
1548+ } else
1549+ return SYCLGenError ();
1550+ } else if (AType->getKind () == InlineAsmBuiltinType::s8 ||
1551+ AType->getKind () == InlineAsmBuiltinType::u8 ) {
1552+ // If A matrix type is s8/u8, then C&D matrix types can only be s32
1553+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1554+ NumVecElements[0 ] = 4 ; // A
1555+ NumVecElements[1 ] = 2 ; // B
1556+ NumVecElements[2 ] = 4 ; // C
1557+ NumVecElements[3 ] = 4 ; // D
1558+ } else
1559+ return SYCLGenError ();
1560+ } else
1561+ return SYCLGenError ();
1562+ } else if (Inst->hasAttr (InstAttr::m16n8k64)) {
1563+ M = " 16" ;
1564+ N = " 8" ;
1565+ K = " 64" ;
1566+ // Only s4/u4 types are supported for A and B matrices of m16n8k64
1567+ if (AType->getKind () == InlineAsmBuiltinType::s4 ||
1568+ AType->getKind () == InlineAsmBuiltinType::u4) {
1569+ // If A matrix type is s4/u4, then C&D matrix types can only be s32
1570+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1571+ NumVecElements[0 ] = 4 ; // A
1572+ NumVecElements[1 ] = 2 ; // B
1573+ NumVecElements[2 ] = 4 ; // C
1574+ NumVecElements[3 ] = 4 ; // D
1575+ } else
1576+ return SYCLGenError ();
1577+ } else
1578+ return SYCLGenError ();
1579+ } else if (Inst->hasAttr (InstAttr::m16n8k128)) {
1580+ M = " 16" ;
1581+ N = " 8" ;
1582+ K = " 128" ;
1583+ // Only b1 type is supported for A and B matrices of m16n8k128
1584+ if (AType->getKind () == InlineAsmBuiltinType::b1) {
1585+ // If A matrix type is b1, then C&D matrix types can only be s32
1586+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1587+ NumVecElements[0 ] = 2 ; // A
1588+ NumVecElements[1 ] = 1 ; // B
1589+ NumVecElements[2 ] = 4 ; // C
1590+ NumVecElements[3 ] = 4 ; // D
1591+
1592+ // Only and/xor bitwise operations are supported for m16n8k128
1593+ if (Inst->hasAttr (InstAttr::op_and))
1594+ MatrixOp = " and" ;
1595+ else if (Inst->hasAttr (InstAttr::op_xor))
1596+ MatrixOp = " xor" ;
1597+ else
1598+ return SYCLGenError ();
1599+ } else
1600+ return SYCLGenError ();
1601+ } else
1602+ return SYCLGenError ();
1603+ } else if (Inst->hasAttr (InstAttr::m16n8k256)) {
1604+ M = " 16" ;
1605+ N = " 8" ;
1606+ K = " 256" ;
1607+ // Only b1 type is supported for A and B matrices of m16n8k256
1608+ if (AType->getKind () == InlineAsmBuiltinType::b1) {
1609+ // If A matrix type is b1, then C&D matrix types can only be s32
1610+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1611+ NumVecElements[0 ] = 4 ; // A
1612+ NumVecElements[1 ] = 2 ; // B
1613+ NumVecElements[2 ] = 4 ; // C
1614+ NumVecElements[3 ] = 4 ; // D
1615+
1616+ // Only and/xor bitwise operations are supported for m16n8k256
1617+ if (Inst->hasAttr (InstAttr::op_and))
1618+ MatrixOp = " and" ;
1619+ else if (Inst->hasAttr (InstAttr::op_xor))
1620+ MatrixOp = " xor" ;
1621+ else
1622+ return SYCLGenError ();
13881623 } else
13891624 return SYCLGenError ();
13901625 } else
@@ -1407,7 +1642,12 @@ class SYCLGen : public SYCLGenBase {
14071642
14081643 MulType = ABType;
14091644 OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1410- OS () << " <" << MulType << " >(" ;
1645+ if (!MatrixOp.empty ()) {
1646+ OS () << " _" << MatrixOp;
1647+ }
1648+ OS () << " <" ;
1649+ OS () << M << " , " << N << " , " << K << " , " ;
1650+ OS () << MulType << " >(" ;
14111651
14121652 // Add D matrix address values to store the MAD result
14131653 for (unsigned Inst = 0 ; Inst != DMatVE->getNumElements (); ++Inst) {
@@ -1416,7 +1656,8 @@ class SYCLGen : public SYCLGenBase {
14161656 OS () << " &" ;
14171657 if (emitStmt (DMatVE->getElement (Inst)))
14181658 return SYCLGenError ();
1419- OS () << " , " ;
1659+ if ((Inst + 1 ) != DMatVE->getNumElements ())
1660+ OS () << " , " ;
14201661 }
14211662
14221663 // Add A, B & C matrix values to compute MAD
@@ -1427,16 +1668,15 @@ class SYCLGen : public SYCLGenBase {
14271668 for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
14281669 if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
14291670 continue ;
1671+ OS () << " , " ;
14301672 if (emitStmt (VE->getElement (Inst)))
14311673 return SYCLGenError ();
1432- OS () << " , " ;
14331674 }
14341675 } else {
14351676 return SYCLGenError ();
14361677 }
14371678 }
14381679
1439- OS () << DpctGlobalInfo::getItem (GAS);
14401680 OS () << " );" ;
14411681
14421682 const auto *KernelDecl = getImmediateOuterFuncDecl (GAS);
0 commit comments