@@ -621,6 +621,7 @@ int main(int argc, char* argv[])
621621 bool has_gemm = false ;
622622 bool has_brgemm = false ;
623623 bool has_matmul = false ;
624+ bool has_naive_matmul = false ;
624625 bool has_unary = false ;
625626 bool has_tensor_operations = false ;
626627 bool has_shared_tensor_operations = false ;
@@ -629,6 +630,7 @@ int main(int argc, char* argv[])
629630 bool has_opt_einsum_benchmark = false ;
630631 bool has_reciprocal = false ;
631632 bool has_sigmoid = false ;
633+ bool has_naive_sigmoid = false ;
632634 for (int i = 1 ; i < argc; ++i)
633635 {
634636 if (strcmp (argv[i], " gemm" ) == 0 )
@@ -637,6 +639,8 @@ int main(int argc, char* argv[])
637639 has_brgemm = true ;
638640 else if (strcmp (argv[i], " matmul" ) == 0 )
639641 has_matmul = true ;
642+ else if (strcmp (argv[i], " naive-matmul" ) == 0 )
643+ has_naive_matmul = true ;
640644 else if (strcmp (argv[i], " unary" ) == 0 )
641645 has_unary = true ;
642646 else if (strcmp (argv[i], " top" ) == 0 )
@@ -653,6 +657,8 @@ int main(int argc, char* argv[])
653657 has_reciprocal = true ;
654658 else if (strcmp (argv[i], " sigmoid" ) == 0 )
655659 has_sigmoid = true ;
660+ else if (strcmp (argv[i], " naive-sigmoid" ) == 0 )
661+ has_naive_sigmoid = true ;
656662 else if (strcmp (argv[i], " help" ) == 0 )
657663 std::cout << " Usage: " << argv[0 ] << " [gemm|brgemm|matmul|unary|top|top-shared|top-opt|einsum|opt-einsum|reciprocal|sigmoid]" << std::endl;
658664 else
@@ -683,6 +689,16 @@ int main(int argc, char* argv[])
683689 matmul_bm.close ();
684690 }
685691
692+ if (has_naive_matmul)
693+ {
694+ mini_jit::benchmarks::NaiveMatmulMNKBench bench_mnk (3.0 , 2048 , 2048 , 2048 );
695+ mini_jit::benchmarks::NaiveMatmulBrMNKBench bench_brmnk (3.0 , 1024 , 1024 , 1024 , 16 );
696+ std::ofstream matmul_bm (" benchmarks/naive_matmul_benchmarks.txt" );
697+ print_throughput (bench_mnk, matmul_bm, " NaiveMatmulMNKBench 2048x2048x2048" );
698+ print_throughput (bench_brmnk, matmul_bm, " NaiveMatmulBrMNKBench 1024x1024x1024 br=16" );
699+ matmul_bm.close ();
700+ }
701+
686702 if (has_unary)
687703 {
688704 const double RUN_TIME = 3.0 ;
@@ -1024,5 +1040,22 @@ int main(int argc, char* argv[])
10241040 sigmoid_bm.close ();
10251041 }
10261042
1043+ if (has_naive_sigmoid)
1044+ {
1045+ const double RUN_TIME = 3.0 ;
1046+ std::ofstream sigmoid_bm (" benchmarks/naive_sigmoid_benchmark.txt" );
1047+
1048+ mini_jit::benchmarks::NaiveSigmoidPrimitiveBench bench_naive_sigmoid_50_50 (RUN_TIME, 50 , 50 );
1049+ mini_jit::benchmarks::NaiveSigmoidPrimitiveBench bench_naive_sigmoid_64_64 (RUN_TIME, 64 , 64 );
1050+ mini_jit::benchmarks::NaiveSigmoidPrimitiveBench bench_naive_sigmoid_512_512 (RUN_TIME, 512 , 512 );
1051+ mini_jit::benchmarks::NaiveSigmoidPrimitiveBench bench_naive_sigmoid_2048_2048 (RUN_TIME, 2048 , 2048 );
1052+ print_bandwidth (bench_naive_sigmoid_50_50, sigmoid_bm, " NaiveSigmoidPrimitiveBench 50x50" );
1053+ print_bandwidth (bench_naive_sigmoid_64_64, sigmoid_bm, " NaiveSigmoidPrimitiveBench 64x64" );
1054+ print_bandwidth (bench_naive_sigmoid_512_512, sigmoid_bm, " NaiveSigmoidPrimitiveBench 512x512" );
1055+ print_bandwidth (bench_naive_sigmoid_2048_2048, sigmoid_bm, " NaiveSigmoidPrimitiveBench 2048x2048" );
1056+
1057+ sigmoid_bm.close ();
1058+ }
1059+
10271060 return 0 ;
10281061}
0 commit comments