Skip to content

Commit 708bba0

Browse files
Enable quantization for FP4 type
1 parent b9ac6f5 commit 708bba0

File tree

7 files changed

+106
-33
lines changed

7 files changed

+106
-33
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ class QuantizedType : public Type {
9797
return -getDefaultMaximumForF8E5M2();
9898
}
9999

100+
/// Returns the default maximum storage value used for the F4E2M1FN storage
/// type when no explicit storage range is given.
static constexpr int64_t getDefaultMaximumForF4E2M1FN() {
  constexpr int64_t kDefaultMaxF4E2M1FN = 6;
  return kDefaultMaxF4E2M1FN;
}
101+
102+
/// Returns the default minimum storage value for the F4E2M1FN storage type.
/// It is defined as the negation of getDefaultMaximumForF4E2M1FN(), so the
/// default storage range is symmetric around zero.
static constexpr int64_t getDefaultMinimumForF4E2M1FN() {
103+
return -getDefaultMaximumForF4E2M1FN();
104+
}
105+
100106
/// Gets the original expressed type that this quantized type approximates.
101107
/// Note that this presumes that the quantized type was always derived from
102108
/// a floating point type, which in the broadest definition, is not true (i.e.
@@ -267,7 +273,7 @@ class AnyQuantizedType
267273
/// Per-layer, optional parameters omitted:
268274
/// !quant<uniform[StorageType]{Scale}>
269275
///
270-
/// StorageType: 'i'|'u' NumBits
276+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
271277
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
272278
/// Scale: A legal double value
273279
/// ZeroPoint: An integer value
@@ -327,7 +333,7 @@ class UniformQuantizedType
327333
/// Per-axis, optional parameters omitted:
328334
/// !quant<uniform[StorageType]{Scale}>
329335
///
330-
/// StorageType: 'i'|'u' NumBits
336+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
331337
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
332338
/// QuantizedDim: An integer value
333339
/// QuantParams: (Scale ':' ZeroPoint)+
@@ -414,7 +420,7 @@ class UniformQuantizedPerAxisType
414420
/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
415421
/// ScaleZero ::= Scale (':' ZeroPoint)?
416422
///
417-
/// StorageType: 'i'|'u' NumBits
423+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
418424
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
419425
/// AxisSpec: An integer value
420426
/// BlockSizeSpec: An integer value
@@ -533,16 +539,16 @@ class UniformQuantizedSubChannelType
533539

534540
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
535541
/// look up table array of quantile values. The type of the data in the look up table is determined by
536-
/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
542+
/// the quantileType member: supported quantileType types are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64.
537543
///
538544
/// Syntax synopsis:
539545
/// Per-layer, all parameters expressed:
540546
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
541547
/// Per-layer, optional parameters omitted:
542548
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
543549
///
544-
/// StorageType: 'i'|'u' NumBits
545-
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
550+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
551+
/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
546552
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
547553
/// Quantiles: Quantile+
548554
/// Quantile: A legal double value
@@ -600,16 +606,16 @@ class QuantileQuantizedType
600606

601607
/// Represents per-axis QuantileQuantizedType (also known as per-channel
602608
/// quantization). The type of the data in the look up table is determined by the
603-
/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
609+
/// quantileType member: supported quantileType types are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64.
604610
///
605611
/// Syntax synopsis:
606612
/// Per-axis, all parameters expressed:
607613
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
608614
/// Per-axis, optional parameters omitted:
609615
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
610616
///
611-
/// StorageType: 'i'|'u' NumBits
612-
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
617+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
618+
/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
613619
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
614620
/// QuantizedDim: An integer value
615621
/// Quantiles: Quantile+

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,18 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
6767
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
6868
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
6969
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
70-
} else if (storageType.isa<Float8E5M2Type>()) {
70+
} else if (mlir::isa<Float8E5M2Type>(storageType)) {
7171
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
7272
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
73-
} else if (storageType.isa<Float8E4M3FNType>()) {
73+
} else if (mlir::isa<Float8E4M3FNType>(storageType)) {
7474
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
7575
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
76+
} else if (mlir::isa<Float4E2M1FNType>(storageType)) {
77+
defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN();
78+
defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN();
7679
} else {
7780
return emitError() << "illegal storage type, supported types are: integral "
78-
"types, Float8E4M3FNType and Float8E5M2Type ";
81+
"types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType ";
7982
}
8083

8184
// Verify storageTypeMin and storageTypeMax.
@@ -574,19 +577,18 @@ LogicalResult QuantileQuantizedType::verifyInvariants(
574577
unsigned typeWidth{};
575578
if (storageType.isa<IntegerType>()) {
576579
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
577-
} else if (storageType.isa<Float8E5M2Type>() ||
578-
storageType.isa<Float8E4M3FNType>()) {
579-
// Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
580+
} else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(storageType)) {
581+
// Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from FloatType.
580582
typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth();
581583
} else {
582584
return emitError() << "illegal storage type, supported types are: integral "
583-
"types, Float8E4M3FNType and Float8E5M2Type ";
585+
"types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType ";
584586
}
585587

586588
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
587589
const size_t typeWidthSize = 1 << typeWidth;
588590
const size_t expectedSize =
589-
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
591+
(storageTypeRange < typeWidthSize) && !mlir::isa<FloatType>(storageType) ? storageTypeRange : typeWidthSize;
590592

591593
const auto quantileArraySize = quantiles.size();
592594
if (quantileArraySize != expectedSize) {
@@ -660,19 +662,18 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
660662
unsigned typeWidth{};
661663
if (storageType.isa<IntegerType>()) {
662664
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
663-
} else if (storageType.isa<Float8E5M2Type>() ||
664-
storageType.isa<Float8E4M3FNType>()) {
665-
// Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
665+
} else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(storageType)) {
666+
// Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from FloatType.
666667
typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth();
667668
} else {
668669
return emitError() << "illegal storage type, supported types are: integral "
669-
"types, Float8E4M3FNType and Float8E5M2Type ";
670+
"types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType ";
670671
}
671672

672673
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
673674
const size_t typeWidthSize = 1 << typeWidth;
674675
const size_t expectedSize =
675-
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
676+
(storageTypeRange < typeWidthSize) && !mlir::isa<FloatType>(storageType) ? storageTypeRange : typeWidthSize;
676677

677678
const auto quantileArraySize = quantiles.size();
678679
if (quantileArraySize != expectedSize) {

mlir/lib/Dialect/Quant/IR/TypeParser.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
3535
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
3636
isSigned = !intType.isUnsigned();
3737
storageTypeWidth = intType.getWidth();
38-
} else if (llvm::dyn_cast<Float8E5M2Type>(type) ||
39-
llvm::dyn_cast<Float8E4M3FNType>(type)) {
40-
storageTypeWidth = 8;
38+
} else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(type)) {
39+
storageTypeWidth = llvm::dyn_cast<FloatType>(type).getWidth();
4140
isSigned = true;
4241
} else {
4342
parser.emitError(typeLoc, "illegal quantized storage type alias");
@@ -132,12 +131,15 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
132131
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
133132
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
134133
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
135-
} else if (storageType.isa<Float8E5M2Type>()) {
134+
} else if (mlir::isa<Float8E5M2Type>(storageType)) {
136135
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
137136
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
138-
} else if (storageType.isa<Float8E4M3FNType>()) {
137+
} else if (mlir::isa<Float8E4M3FNType>(storageType)) {
139138
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
140139
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
140+
} else if (mlir::isa<Float4E2M1FNType>(storageType)) {
141+
defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN();
142+
defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN();
141143
} else {
142144
defaultMin = std::numeric_limits<int64_t>::max();
143145
defaultMax = std::numeric_limits<int64_t>::min();
@@ -150,7 +152,7 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
150152
}
151153

152154
// Explicit storage min and storage max.
153-
// F8 min and max values are integers, so parseInteger() is used.
155+
// F8 and F4 min and max values are integers, so parseInteger() is used.
154156
SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
155157
if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
156158
parser.getCurrentLocation(&maxLoc) ||
@@ -382,7 +384,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
382384
/// block-size-info `,` scale-zero-tensor `>`
383385
/// storage-spec ::= storage-type (`<` storage-range `>`)?
384386
/// storage-range ::= integer-literal `:` integer-literal
385-
/// storage-type ::= (`i` | `u`) integer-literal
387+
/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN`
386388
/// expressed-type-spec ::= `:` `f` integer-literal
387389
/// axis-spec ::= `:` integer-literal
388390
/// scale-zero ::= scale (`:` zero-point)?
@@ -407,9 +409,9 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
407409
/// scale-zero-list `>`
408410
/// storage-spec ::= storage-type (`<` storage-range `>`)?
409411
/// storage-range ::= integer-literal `:` integer-literal
410-
/// storage-type ::= (`i` | `u`) integer-literal
412+
/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN`
411413
/// quantile-type-spec ::= `:` ((`i` | `u` | `f`) integer-literal | `f8E5M2` |
412-
/// `f8E4M3FN`)
414+
/// `f8E4M3FN` | `f4E2M1FN`)
413415
/// expressed-type-spec ::= `:` `f` integer-literal
414416
/// axis-spec ::= `:` integer-literal
/// quantiles-list ::= `{` quantile (`,` quantile)* `}`
415417
/// scale-zero ::= `:` float-literal `:` integer-literal
@@ -641,6 +643,8 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
641643
out << "f8E5M2";
642644
} else if (type.getStorageType().isa<Float8E4M3FNType>()) {
643645
out << "f8E4M3FN";
646+
} else if (type.getStorageType().isa<Float4E2M1FNType>()) {
647+
out << "f4E2M1FN";
644648
} else if (isSigned) {
645649
out << "i" << storageWidth;
646650
} else {
@@ -655,7 +659,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
655659
? QuantizedType::getDefaultMinimumForF8E5M2()
656660
: type.getStorageType().isa<Float8E4M3FNType>()
657661
? QuantizedType::getDefaultMinimumForF8E4M3FN()
658-
: std::numeric_limits<int64_t>::max();
662+
: type.getStorageType().isa<Float4E2M1FNType>()
663+
? QuantizedType::getDefaultMinimumForF4E2M1FN()
664+
: std::numeric_limits<int64_t>::max();
659665

660666
int64_t defaultMax =
661667
type.getStorageType().isa<IntegerType>()
@@ -664,7 +670,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
664670
? QuantizedType::getDefaultMaximumForF8E5M2()
665671
: type.getStorageType().isa<Float8E4M3FNType>()
666672
? QuantizedType::getDefaultMaximumForF8E4M3FN()
667-
: std::numeric_limits<int64_t>::min();
673+
: type.getStorageType().isa<Float4E2M1FNType>()
674+
? QuantizedType::getDefaultMaximumForF4E2M1FN()
675+
: std::numeric_limits<int64_t>::min();
668676

669677
if (defaultMin != type.getStorageTypeMin() ||
670678
defaultMax != type.getStorageTypeMax()) {
@@ -685,6 +693,8 @@ static void printQuantileType(Type quantileType, DialectAsmPrinter &out) {
685693
out << ":f8E5M2";
686694
} else if (quantileType.isa<Float8E4M3FNType>()) {
687695
out << ":f8E4M3FN";
696+
} else if (quantileType.isa<Float4E2M1FNType>()) {
697+
out << ":f4E2M1FN";
688698
} else {
689699
// Float types
690700
out << ":" << quantileType;

mlir/test/Dialect/Quant/parse-quantile-invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,16 @@ func.func @parse() -> !qalias {
126126
// expected-error@+1 {{illegal storage type minimum: -500}}
127127
!qalias = !quant.quantile<f8E4M3FN<-500:448>:f16:f32, {-1.0,1.0}:0.99872:127>
128128

129+
// -----
130+
// Illegal storage min/max: max > defaultMax
131+
// expected-error@+1 {{illegal storage type maximum: 10}}
132+
!qalias = !quant.quantile<f4E2M1FN<-6:10>:f16:f32, {-1.0,1.0}:0.99872:127>
133+
134+
// -----
135+
// Illegal storage min/max: min < defaultMin
136+
// expected-error@+1 {{illegal storage type minimum: -10}}
137+
!qalias = !quant.quantile<f4E2M1FN<-10:6>:f16:f32, {-1.0,1.0}:0.99872:127>
138+
129139
// -----
130140
// Illegal uniform params: invalid scale
131141
// expected-error@+1 {{expected floating point literal}}

mlir/test/Dialect/Quant/parse-quantile.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ func.func @parse() -> !qalias {
4646
return %0 : !qalias
4747
}
4848

49+
// -----
50+
// Default min/max value optimization for f4E2M1FN.
51+
// CHECK: !quant.quantile<f4E2M1FN:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:9.987200e-01:127>
52+
!qalias = !quant.quantile<f4E2M1FN<-6:6>:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 >
53+
func.func @parse() -> !qalias {
54+
%0 = "foo"() : () -> !qalias
55+
return %0 : !qalias
56+
}
57+
4958
// -----
5059
// Required per-layer params specified:
5160
// [unsigned] storageType, expressedType, scale
@@ -92,6 +101,15 @@ func.func @parse() -> !qalias {
92101
return %0 : !qalias
93102
}
94103

104+
// -----
105+
// Storage type: f4E2M1FN
106+
// CHECK: !quant.quantile<f4E2M1FN:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:2.000000e+02>
107+
!qalias = !quant.quantile<f4E2M1FN:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:2.0e+2>
108+
func.func @parse() -> !qalias {
109+
%0 = "foo"() : () -> !qalias
110+
return %0 : !qalias
111+
}
112+
95113
// -----
96114
// Expressed type: f32
97115
// CHECK: !quant.quantile<u4:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:2.000000e+02>

mlir/test/Dialect/Quant/parse-uniform-invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,16 @@
100100
// expected-error@+1 {{illegal storage type minimum: -500}}
101101
!qalias = !quant.uniform<f8E4M3FN<-500:448>:f32, 0.99872:127>
102102

103+
// -----
104+
// Illegal storage min/max: max > defaultMax
105+
// expected-error@+1 {{illegal storage type maximum: 10}}
106+
!qalias = !quant.uniform<f4E2M1FN<-6:10>:f32, 0.99872:127>
107+
108+
// -----
109+
// Illegal storage min/max: min < defaultMin
110+
// expected-error@+1 {{illegal storage type minimum: -10}}
111+
!qalias = !quant.uniform<f4E2M1FN<-10:6>:f32, 0.99872:127>
112+
103113
// -----
104114
// Illegal uniform params: invalid scale
105115
// expected-error@+1 {{expected floating point literal}}

mlir/test/Dialect/Quant/parse-uniform.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ func.func @parse() -> !qalias {
4646
return %0 : !qalias
4747
}
4848

49+
// -----
50+
// Default min/max value optimization for f4E2M1FN.
51+
// CHECK: !quant.uniform<f4E2M1FN:f32, 9.987200e-01:127>
52+
!qalias = !quant.uniform<f4E2M1FN<-6:6>:f32, 0.99872:127 >
53+
func.func @parse() -> !qalias {
54+
%0 = "foo"() : () -> !qalias
55+
return %0 : !qalias
56+
}
57+
4958
// -----
5059
// Required per-layer params specified:
5160
// [unsigned] storageType, expressedType, scale
@@ -92,6 +101,15 @@ func.func @parse() -> !qalias {
92101
return %0 : !qalias
93102
}
94103

104+
// -----
105+
// Storage type: f4E2M1FN
106+
// CHECK: !quant.uniform<f4E2M1FN:f32, 2.000000e+02>
107+
!qalias = !quant.uniform<f4E2M1FN:f32, 2.0e+2>
108+
func.func @parse() -> !qalias {
109+
%0 = "foo"() : () -> !qalias
110+
return %0 : !qalias
111+
}
112+
95113
// -----
96114
// Storage type: i16
97115
// CHECK: !quant.uniform<i16:f32, 2.000000e+02>

0 commit comments

Comments
 (0)