@@ -266,9 +266,185 @@ export const createMatMulNBitsProgramInfo = (
266266 } ;
267267} ;
268268
/**
 * Builds the program info for the MatMulNBits kernel variant that stages tiles of `a` in workgroup memory.
 *
 * Currently, only blockSize = 32 is supported.
 *
 * Each workgroup has 128 invocations laid out as X * Y. Y is 8, 4 or 1 depending on which of 8 / 4 divides
 * `attributes.n`, and every workgroup produces Y consecutive output columns of one output row. The X invocations
 * of a column each accumulate a partial sum into `inter_results`, which the first Y invocations then reduce.
 *
 * @param inputs - `[a, b, scales]` with an optional 4th `zero_points` tensor.
 * @param attributes - operator attributes; `k`, `n` and `blockSize` are read here.
 * @returns the program info (shader generator, cache key and run data) for this kernel.
 */
export const createMatMulNBitsBlockSize32ProgramInfo = (
  inputs: readonly TensorView[],
  attributes: MatMulNBitsAttributes,
): ProgramInfo => {
  const aDims = inputs[0].dims;
  const rankA = aDims.length;
  const rows = aDims[rankA - 2];
  const depth = attributes.k;
  const cols = attributes.n;
  const leadingDims = aDims.slice(0, rankA - 2);
  const batchCount = ShapeUtil.size(leadingDims);
  // The last dim of b is divided by 4 to express one blob in uint32 words
  // (presumably it is a byte count - confirm against the operator spec).
  const wordsPerBlob = inputs[1].dims[2] / 4;
  const outputDataType = inputs[0].dataType;
  const aComponents = getMaxComponents(attributes.k);
  const bComponents = getMaxComponents(wordsPerBlob);
  const outputDims = leadingDims.concat([rows, cols]);
  const hasZeroPoints = inputs.length === 4;

  // Workgroup layout: 128 invocations, split so that Y evenly divides the number of output columns.
  const threadsPerGroup = 128;
  let groupY = 1;
  if (cols % 8 === 0) {
    groupY = 8;
  } else if (cols % 4 === 0) {
    groupY = 4;
  }
  const groupX = threadsPerGroup / groupY;
  const valuesPerTile = groupX * bComponents * 8; // each uint32 has 8 data.
  const aElementsPerTile = valuesPerTile / aComponents;
  const blocksInTile = valuesPerTile / attributes.blockSize;
  const groupCount = ShapeUtil.size(outputDims) / groupY;

  // Shapes as seen by the shader: batch dims are flattened and vectorized dims are divided by their width.
  const aUniformShape = [batchCount, rows, depth / aComponents];
  const bUniformShape = ShapeUtil.convertShape(inputs[1].dims).slice();
  bUniformShape.splice(-1, 1, wordsPerBlob / bComponents);
  const outputUniformShape = [batchCount, rows, cols];
  const programUniforms: ProgramUniform[] = [
    ...createTensorShapeVariables(aUniformShape),
    ...createTensorShapeVariables(bUniformShape),
    ...createTensorShapeVariables(inputs[2].dims),
    ...(hasZeroPoints ? createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)) : []),
    ...createTensorShapeVariables(outputUniformShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const aVar = inputVariable('a', inputs[0].dataType, aUniformShape.length, aComponents);
    const bVar = inputVariable('b', DataType.uint32, bUniformShape.length, bComponents);
    const scalesVar = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
    const shaderInputs = [aVar, bVar, scalesVar];
    const zeroPointsVar = hasZeroPoints
      ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length)
      : undefined;
    if (zeroPointsVar) {
      shaderInputs.push(zeroPointsVar);
    }
    const outputVar = outputVariable('output', inputs[0].dataType, outputUniformShape.length);
    const storageType = tensorTypeToWsglStorageType(inputs[0].dataType);

    // Emits the WGSL that gathers 8 consecutive values of `a` from the staged tile into two vec4s.
    // Kept lazy so that an unsupported component count throws at the point of use in the template.
    const emitReadA = (): string => {
      const snippets = new Map<number, string>([
        [
          1,
          `
              let a_data0 = vec4<${storageType}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);
              let a_data1 = vec4<${storageType}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`,
        ],
        [
          2,
          `
              let a_data0 = vec4<${storageType}>(sub_a[word_offset], sub_a[word_offset + 1]);
              let a_data1 = vec4<${storageType}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`,
        ],
        [
          4,
          `
              let a_data0 = sub_a[word_offset];
              let a_data1 = sub_a[word_offset + 1];`,
        ],
      ]);
      const snippet = snippets.get(aComponents);
      if (snippet === undefined) {
        throw new Error(`${aComponents}-component is not supported.`);
      }
      return snippet;
    };

    // Emits the WGSL that defines `zero_point` for the current block. With a zero_points input, the 4-bit
    // value is extracted from the packed words; otherwise a constant is used. Kept lazy so the helper call
    // happens at its original position in the template.
    const emitZeroPoint = (): string => {
      if (zeroPointsVar) {
        return `
            let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
            let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
            let zero_point_word_index = zero_point_byte_count >> 0x2u;
            let zero_point_byte_offset = zero_point_byte_count & 0x3u;
            let zero_point_nibble_offset: u32 = block & 0x1u;
            let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
            let zero_point_word = ${zeroPointsVar.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
            let zero_point = ${storageType}((zero_point_word) & 0xFu);`;
      }
      return `
            // The default zero point is 8 for unsigned 4-bit quantization.
            let zero_point = ${storageType}(${8.0});`;
    };

    // Pure string fragments used inside the inner loop of the shader.
    const quantizedArgs = [0, 1, 2, 3]
      .map((lane) => `${storageType}(b_value_lower[${lane}]), ${storageType}(b_value_upper[${lane}])`)
      .join(', ');
    const zeroPointArgs = Array.from({ length: 8 }, () => 'zero_point').join(',');
    const dotSum = [0, 1].map((half) => `dot(a_data${half}, b_dequantized_values[${half}])`).join(' + ');
    const bWordExpr = bComponents === 1 ? `b_data` : `b_data[i]`;

    return `
        var<workgroup> sub_a: array<${aVar.type.value}, ${aElementsPerTile}>;
        var<workgroup> inter_results: array<array<${outputVar.type.value}, ${groupX}>, ${groupY}>;
        ${shaderHelper.declareVariables(...shaderInputs, outputVar)}
        ${shaderHelper.mainStart([groupX, groupY, 1])}
          let output_indices = ${outputVar.offsetToIndices(`workgroup_index * ${groupY}`)};
          let col = output_indices[2];
          let row = output_indices[1];
          let batch = output_indices[0];
          let n_blocks_per_col = uniforms.b_shape[1];
          let num_tiles = (n_blocks_per_col - 1) / ${blocksInTile} + 1;

          // Loop over shared dimension.
          for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
            let a_col_start = tile * ${aElementsPerTile};
            // load one tile A data into shared memory.
            for (var a_offset = local_idx; a_offset < ${aElementsPerTile}; a_offset += ${threadsPerGroup})
            {
              let a_col = a_col_start + a_offset;
              if (a_col < uniforms.a_shape[2])
              {
                sub_a[a_offset] = ${aVar.getByIndices(`${aVar.type.indices}(batch, row, a_col)`)};
              } else {
                sub_a[a_offset] = ${aVar.type.value}(0);
              }
            }
            workgroupBarrier();

            // each thread process one block
            let b_row = col + local_id.y;
            let block = tile * ${blocksInTile} + local_id.x;
            ${emitZeroPoint()}
            let scale = ${scalesVar.getByOffset(`b_row * n_blocks_per_col + block`)};
            let b_data = ${bVar.getByIndices(`${bVar.type.indices}(b_row, block, 0)`)};
            var word_offset = local_id.x * ${attributes.blockSize / aComponents};
            for (var i: u32 = 0; i < ${bComponents}; i++) {
              ${emitReadA()}
              let b_value = ${bWordExpr};
              let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
              let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
              let b_quantized_values = mat2x4<${storageType}>(${quantizedArgs});
              let b_dequantized_values = (b_quantized_values - mat2x4<${storageType}>(${zeroPointArgs})) * scale;
              inter_results[local_id.y][local_id.x] += ${dotSum};
              word_offset += ${8 / aComponents};
            }
            workgroupBarrier();
          }

          if (local_idx < ${groupY}) {
            var output_value: ${outputVar.type.value} = ${outputVar.type.value}(0);
            for (var b = 0u; b < ${groupX}; b++) {
              output_value += inter_results[local_idx][b];
            }
            if (col + local_idx < uniforms.output_shape[2])
            {
              ${outputVar.setByIndices(`${outputVar.type.indices}(batch, row, col + local_idx)`, 'output_value')}
            }
          }
        }`;
  };

  return {
    name: 'BlockwiseMatMulNBits32',
    shaderCache: {
      // Everything besides input ranks that changes the generated shader text.
      hint: `${attributes.blockSize};${aComponents};${bComponents};${groupX};${groupY}`,
      inputDependencies: Array(inputs.length).fill('rank'),
    },
    getRunData: () => ({
      outputs: [{ dims: outputDims, dataType: outputDataType }],
      dispatchGroup: { x: groupCount },
      programUniforms,
    }),
    getShaderSource,
  };
};
436+
/**
 * Entry point of the MatMulNBits operator.
 *
 * Validates the inputs and then runs exactly one program: the block-size-32 kernel when
 * `attributes.blockSize` is 32 and the adapter reports vendor 'intel' with architecture 'gen-12lp',
 * otherwise the generic kernel.
 *
 * @param context - compute context providing the inputs, the adapter info and `compute`.
 * @param attributes - parsed MatMulNBits attributes.
 */
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
  validateInputs(context.inputs, attributes);
  // The generic kernel must only be dispatched from the fallback branch below; an additional
  // unconditional dispatch here would run the operator twice.
  if (
    attributes.blockSize === 32 &&
    context.adapterInfo.isVendor('intel') &&
    context.adapterInfo.isArchitecture('gen-12lp')
  ) {
    context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
  } else {
    context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
  }
};
273449
274450export const parseMatMulNBitsAttributes = ( attributes : Record < string , unknown > ) : MatMulNBitsAttributes =>
0 commit comments