@@ -35,9 +35,8 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
3535 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
3636 isSigned = !intType.isUnsigned ();
3737 storageTypeWidth = intType.getWidth ();
38- } else if (llvm::dyn_cast<Float8E5M2Type>(type) ||
39- llvm::dyn_cast<Float8E4M3FNType>(type)) {
40- storageTypeWidth = 8 ;
38+ } else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(type)) {
39+ storageTypeWidth = llvm::dyn_cast<FloatType>(type).getWidth ();
4140 isSigned = true ;
4241 } else {
4342 parser.emitError (typeLoc, " illegal quantized storage type alias" );
@@ -132,12 +131,15 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
132131 const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
133132 defaultMin = QuantizedType::getDefaultMinimumForInteger (isSigned, width);
134133 defaultMax = QuantizedType::getDefaultMaximumForInteger (isSigned, width);
135- } else if (storageType. isa <Float8E5M2Type>()) {
134+ } else if (mlir:: isa<Float8E5M2Type>(storageType )) {
136135 defaultMin = QuantizedType::getDefaultMinimumForF8E5M2 ();
137136 defaultMax = QuantizedType::getDefaultMaximumForF8E5M2 ();
138- } else if (storageType. isa <Float8E4M3FNType>()) {
137+ } else if (mlir:: isa<Float8E4M3FNType>(storageType )) {
139138 defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN ();
140139 defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN ();
140+ } else if (mlir::isa<Float4E2M1FNType>(storageType)) {
141+ defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN ();
142+ defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN ();
141143 } else {
142144 defaultMin = std::numeric_limits<int64_t >::max ();
143145 defaultMax = std::numeric_limits<int64_t >::min ();
@@ -150,7 +152,7 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
150152 }
151153
152154 // Explicit storage min and storage max.
153- // F8 min and max values are integers, so parseInteger() is used.
155+ // F8 and F4 min and max values are integers, so parseInteger() is used.
154156 SMLoc minLoc = parser.getCurrentLocation (), maxLoc;
155157 if (parser.parseInteger (storageTypeMin) || parser.parseColon () ||
156158 parser.getCurrentLocation (&maxLoc) ||
@@ -382,7 +384,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
382384// / block-size-info `,` scale-zero-tensor `>`
383385// / storage-spec ::= storage-type (`<` storage-range `>`)?
384386// / storage-range ::= integer-literal `:` integer-literal
385- // / storage-type ::= (`i` | `u`) integer-literal
387+ // / storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN`
386388// / expressed-type-spec ::= `:` `f` integer-literal
387389// / axis-spec ::= `:` integer-literal
388390// / scale-zero ::= scale (`:` zero-point)?
@@ -407,9 +409,9 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
407409// / scale-zero-list `>`
408410// / storage-spec ::= storage-type (`<` storage-range `>`)?
409411// / storage-range ::= integer-literal `:` integer-literal
410- // / storage-type ::= (`i` | `u`) integer-literal
412+ // / storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN`
411413// / quantile-type-spec ::= `:` ((`i` | `u` | `f`) integer-literal | `f8E5M2` |
412- // / `f8E4M3FN`)
414+ // / `f8E4M3FN` | `f4E2M1FN` )
413415// / expressed-type-spec ::= `:` `f` integer-literal axis-spec ::=
414416// / `:` integer-literal quantiles-list ::= `{` quantile (`,` quantile)* `}`
415417// / scale-zero ::= `:` float-literal `:` integer-literal
@@ -641,6 +643,8 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
641643 out << " f8E5M2" ;
642644 } else if (type.getStorageType ().isa <Float8E4M3FNType>()) {
643645 out << " f8E4M3FN" ;
646+ } else if (type.getStorageType ().isa <Float4E2M1FNType>()) {
647+ out << " f4E2M1FN" ;
644648 } else if (isSigned) {
645649 out << " i" << storageWidth;
646650 } else {
@@ -655,7 +659,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
655659 ? QuantizedType::getDefaultMinimumForF8E5M2 ()
656660 : type.getStorageType ().isa <Float8E4M3FNType>()
657661 ? QuantizedType::getDefaultMinimumForF8E4M3FN ()
658- : std::numeric_limits<int64_t >::max ();
662+ : type.getStorageType ().isa <Float4E2M1FNType>()
663+ ? QuantizedType::getDefaultMinimumForF4E2M1FN ()
664+ : std::numeric_limits<int64_t >::max ();
659665
660666 int64_t defaultMax =
661667 type.getStorageType ().isa <IntegerType>()
@@ -664,7 +670,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
664670 ? QuantizedType::getDefaultMaximumForF8E5M2 ()
665671 : type.getStorageType ().isa <Float8E4M3FNType>()
666672 ? QuantizedType::getDefaultMaximumForF8E4M3FN ()
667- : std::numeric_limits<int64_t >::min ();
673+ : type.getStorageType ().isa <Float4E2M1FNType>()
674+ ? QuantizedType::getDefaultMaximumForF4E2M1FN ()
675+ : std::numeric_limits<int64_t >::min ();
668676
669677 if (defaultMin != type.getStorageTypeMin () ||
670678 defaultMax != type.getStorageTypeMax ()) {
@@ -685,6 +693,8 @@ static void printQuantileType(Type quantileType, DialectAsmPrinter &out) {
685693 out << " :f8E5M2" ;
686694 } else if (quantileType.isa <Float8E4M3FNType>()) {
687695 out << " :f8E4M3FN" ;
696+ } else if (quantileType.isa <Float4E2M1FNType>()) {
697+ out << " :f4E2M1FN" ;
688698 } else {
689699 // Float types
690700 out << " :" << quantileType;
0 commit comments