@@ -757,36 +757,74 @@ static const std::string gpu_pipeline =
757757 " func.func(convert-parallel-loops-to-gpu),"
758758 // insert-gpu-allocs pass can have client-api = opencl or vulkan args
759759 " func.func(insert-gpu-allocs{in-regions=1}),"
760- // ** imex GPU passes
761- // "drop-regions,"
762- // "canonicalize,"
763- // // "normalize-memrefs,"
764- // // "gpu-decompose-memrefs,"
765- // "func.func(lower-affine),"
766- // "gpu-kernel-outlining,"
767- // "canonicalize,"
768- // "cse,"
769- // // The following set-spirv-* passes can have client-api = opencl or
770- // vulkan
771- // // args
772- // "set-spirv-capabilities{client-api=opencl},"
773- // "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
774- // "canonicalize,"
775- // "fold-memref-alias-ops,"
776- // "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
777- // "spirv.module(spirv-lower-abi-attrs),"
778- // "spirv.module(spirv-update-vce),"
779- // // "func.func(llvm-request-c-wrappers),"
780- // "serialize-spirv,"
781- // "expand-strided-metadata,"
782- // "lower-affine,"
783- // "convert-gpu-to-gpux,"
784- // "convert-func-to-llvm,"
785- // "convert-math-to-llvm,"
786- // "convert-gpux-to-llvm,"
787- // "finalize-memref-to-llvm,"
788- // "reconcile-unrealized-casts";
789- // ** nv GPU passes
760+ " drop-regions,"
761+ " canonicalize,"
762+ // "normalize-memrefs,"
763+ // "gpu-decompose-memrefs,"
764+ " func.func(lower-affine),"
765+ " gpu-kernel-outlining,"
766+ " canonicalize,"
767+ " cse,"
768+ // The following set-spirv-* passes accept either client-api=opencl or
769+ // client-api=vulkan as an argument.
770+ " set-spirv-capabilities{client-api=opencl},"
771+ " gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
772+ " canonicalize,"
773+ " fold-memref-alias-ops,"
774+ " imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
775+ " spirv.module(spirv-lower-abi-attrs),"
776+ " spirv.module(spirv-update-vce),"
777+ // "func.func(llvm-request-c-wrappers),"
778+ " serialize-spirv,"
779+ " expand-strided-metadata,"
780+ " lower-affine,"
781+ " convert-gpu-to-gpux,"
782+ " convert-func-to-llvm,"
783+ " convert-math-to-llvm,"
784+ " convert-gpux-to-llvm,"
785+ " finalize-memref-to-llvm,"
786+ " reconcile-unrealized-casts" ;
787+
788+ static const std::string cuda_pipeline =
789+ " add-gpu-regions,"
790+ " canonicalize,"
791+ " ndarray-dist,"
792+ " func.func(dist-coalesce),"
793+ " func.func(dist-infer-elementwise-cores),"
794+ " convert-dist-to-standard,"
795+ " canonicalize,"
796+ " overlap-comm-and-compute,"
797+ " add-comm-cache-keys,"
798+ " lower-distruntime-to-idtr,"
799+ " convert-ndarray-to-linalg,"
800+ " canonicalize,"
801+ " func.func(tosa-make-broadcastable),"
802+ " func.func(tosa-to-linalg),"
803+ " func.func(tosa-to-tensor),"
804+ " canonicalize,"
805+ " linalg-fuse-elementwise-ops,"
806+ " arith-expand,"
807+ " memref-expand,"
808+ " arith-bufferize,"
809+ " func-bufferize,"
810+ " func.func(empty-tensor-to-alloc-tensor),"
811+ " func.func(scf-bufferize),"
812+ " func.func(tensor-bufferize),"
813+ " func.func(bufferization-bufferize),"
814+ " func.func(linalg-bufferize),"
815+ " func.func(linalg-detensorize),"
816+ " func.func(tensor-bufferize),"
817+ " region-bufferize,"
818+ " canonicalize,"
819+ " func.func(finalizing-bufferize),"
820+ " imex-remove-temporaries,"
821+ " func.func(convert-linalg-to-parallel-loops),"
822+ " func.func(scf-parallel-loop-fusion),"
823+ // is add-outer-parallel-loop needed?
824+ " func.func(imex-add-outer-parallel-loop),"
825+ " func.func(gpu-map-parallel-loops),"
826+ " func.func(convert-parallel-loops-to-gpu),"
827+ " func.func(insert-gpu-allocs{in-regions=1}),"
790828 " func.func(insert-gpu-copy),"
791829 " drop-regions,"
792830 " canonicalize,"
@@ -808,7 +846,9 @@ static const std::string gpu_pipeline =
808846
// Pass-pipeline override: holds whatever get_text_env returns for the
// SHARPY_PASSES key (presumably an environment-variable lookup -- confirm
// against get_text_env's definition, which is not visible here).
809847const std::string _passes (get_text_env (" SHARPY_PASSES" ));
// Reference to the pipeline string that gets selected at static-init time.
// Selection order in the new ('+') lines:
//   1. _passes, when it does not compare equal to the literal left of '?'
//      (NOTE(review): rendered as " " in this view; presumably the empty
//      string in the real file -- confirm);
//   2. otherwise, if useGPU() is true: cuda_pipeline when useCUDA() is also
//      true, else gpu_pipeline;
//   3. otherwise cpu_pipeline.
// The removed ('-') line made the same choice but without the useCUDA() split.
810848static const std::string &pass_pipeline =
811- _passes != " " ? _passes : (useGPU () ? gpu_pipeline : cpu_pipeline);
849+ _passes != " " ? _passes
850+ : (useGPU () ? (useCUDA () ? cuda_pipeline : gpu_pipeline)
851+ : cpu_pipeline);
812852
813853JIT::JIT (const std::string &libidtr)
814854 : _context (::mlir::MLIRContext::Threading::DISABLED), _pm (&_context),
@@ -860,23 +900,24 @@ JIT::JIT(const std::string &libidtr)
860900 _crunnerlib = mlirRoot + " /lib/libmlir_c_runner_utils.so" ;
861901 _runnerlib = mlirRoot + " /lib/libmlir_runner_utils.so" ;
862902 if (!std::ifstream (_crunnerlib)) {
863- throw std::runtime_error (" Cannot find libmlir_c_runner_utils.so " );
903+ throw std::runtime_error (" Cannot find lib: " + _crunnerlib );
864904 }
865905 if (!std::ifstream (_runnerlib)) {
866- throw std::runtime_error (" Cannot find libmlir_runner_utils.so " );
906+ throw std::runtime_error (" Cannot find lib: " + _runnerlib );
867907 }
868908
869909 if (useGPU ()) {
870910 auto gpuxlibstr = get_text_env (" SHARPY_GPUX_SO" );
871911 if (!gpuxlibstr.empty ()) {
872912 _gpulib = std::string (gpuxlibstr);
873913 } else {
874- // auto imexRoot = get_text_env("IMEXROOT");
875- // imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
876- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
877- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
878- // for nv gpu
879- _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
914+ if (useCUDA ()) {
915+ _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
916+ } else {
917+ auto imexRoot = get_text_env (" IMEXROOT" );
918+ imexRoot = !imexRoot.empty () ? imexRoot : std::string (CMAKE_IMEX_ROOT);
919+ _gpulib = imexRoot + " /lib/liblevel-zero-runtime.so" ;
920+ }
880921 if (!std::ifstream (_gpulib)) {
881922 throw std::runtime_error (" Cannot find lib: " + _gpulib);
882923 }
0 commit comments