diff --git a/.github/workflows/build_llvm.yml b/.github/workflows/build_llvm.yml index edfb419c..f18f5906 100644 --- a/.github/workflows/build_llvm.yml +++ b/.github/workflows/build_llvm.yml @@ -18,6 +18,11 @@ on: description: 'Run the build with a tmate session ONLY in case of failure' required: false default: false + release: + description: 'whether to release' + type: boolean + required: false + default: true pull_request: paths: - ".github/actions/setup_base" @@ -210,7 +215,7 @@ jobs: path: ${{ startsWith(matrix.os, 'windows') && 'D:\a\ccache.log' || '/tmp/ccache.log' }} - name: Release current commit - if: (!cancelled() && ((github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch')) + if: (!cancelled() && ((github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release))) uses: ncipollo/release-action@v1.12.0 with: artifacts: "*.tar.gz,wheelhouse/*.whl" @@ -252,6 +257,7 @@ jobs: wheel_version: ${{ needs.build.outputs.WHEEL_VERSION }} workflow_call: true workflow_caller_run_id: ${{ github.run_id }} + release: ${{ inputs.release }} call-build-eudsl: @@ -268,10 +274,11 @@ jobs: with: workflow_call: true workflow_caller_run_id: ${{ github.run_id }} + release: ${{ inputs.release }} call-deploy-pip-page: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) needs: [build] diff --git a/.github/workflows/build_mlir_python_bindings_wheel.yml b/.github/workflows/build_mlir_python_bindings_wheel.yml index 0e42ed48..04625901 100644 --- a/.github/workflows/build_mlir_python_bindings_wheel.yml +++ b/.github/workflows/build_mlir_python_bindings_wheel.yml @@ -18,6 +18,11 @@ on: description: 'Run the build with a tmate session ONLY in case of failure' required: false default: false + release: + 
description: 'whether to release' + type: boolean + required: false + default: true workflow_call: inputs: wheel_version: @@ -35,6 +40,11 @@ on: type: string required: false default: '' + release: + description: 'whether to release' + type: boolean + required: false + default: true pull_request: branches: - main @@ -289,7 +299,7 @@ jobs: release-mlir-python-bindings: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) needs: [build-mlir-python-bindings] @@ -426,7 +436,7 @@ jobs: name: build_artifact_python_bindings-ubuntu-wasm - name: Release current commit - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) uses: ncipollo/release-action@v1.12.0 with: artifacts: "wheelhouse/mlir_python_bindings*.whl" @@ -454,11 +464,12 @@ jobs: with: workflow_call: true workflow_caller_run_id: ${{ github.run_id }} + release: ${{ inputs.release }} call-deploy-pip-page: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) needs: [release-mlir-python-bindings] diff --git a/.github/workflows/build_test_release_eudsl.yml b/.github/workflows/build_test_release_eudsl.yml index bb8bb808..101c39df 100644 --- a/.github/workflows/build_test_release_eudsl.yml +++ b/.github/workflows/build_test_release_eudsl.yml @@ -18,6 +18,11 @@ on: description: 'Run the build with a tmate session ONLY in case of failure' required: false default: false + release: + description: 'whether to release' + type: boolean + required: false + default: true workflow_call: 
inputs: workflow_call: @@ -30,6 +35,11 @@ on: type: string required: false default: '' + release: + description: 'whether to release' + type: boolean + required: false + default: true pull_request: branches: - main @@ -430,8 +440,10 @@ jobs: release-eudsl: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) + needs: [build-eudsl] + runs-on: "ubuntu-22.04" permissions: @@ -473,7 +485,7 @@ jobs: call-deploy-pip-page: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) needs: [release-eudsl] diff --git a/.github/workflows/build_test_release_eudsl_python_extras.yml b/.github/workflows/build_test_release_eudsl_python_extras.yml index 19a2e49b..7fc5ac9d 100644 --- a/.github/workflows/build_test_release_eudsl_python_extras.yml +++ b/.github/workflows/build_test_release_eudsl_python_extras.yml @@ -7,6 +7,12 @@ name: "Build, test, release eudsl-python-extras" on: workflow_dispatch: + inputs: + release: + description: 'whether to release' + type: boolean + required: false + default: true workflow_call: inputs: workflow_call: @@ -19,6 +25,11 @@ on: type: string required: false default: '' + release: + description: 'whether to release' + type: boolean + required: false + default: true pull_request: branches: - main @@ -186,9 +197,46 @@ jobs: python -m pytest projects/eudsl-python-extras/tests $IGNORE + - name: "Test examples" + # The script below uses the bash-only [[ ]] test; pin bash explicitly (as the notebook step does) so it does not depend on the runner's default shell. + shell: bash + run: | + + python projects/eudsl-python-extras/examples/flash_attention.py + python projects/eudsl-python-extras/examples/mwe.py + python projects/eudsl-python-extras/examples/rdna_matmul_opt.py + + if [[ $(python -c "print(__import__('sys').version_info >= (3, 13))") == "True" ]]; then +
python projects/eudsl-python-extras/examples/cuda_matmul_opt.py + fi + + - name: Test jupyter notebooks + # sed: can't read C:\hostedtoolcache\windows\Python\3.12.10\x64/jupyter_client/runapp.py: No such file or directory + if: matrix.os != 'windows' + shell: bash + env: + BRANCH: ${{ github.head_ref || github.ref_name }} + run: | + + pip install -q jupyter + + sed -i.bak 's/OUTPUT_TIMEOUT = 10/OUTPUT_TIMEOUT = 1000/g' \ + $(python -c 'import site; print(site.getsitepackages()[0])')/jupyter_client/runapp.py + + jupyter execute projects/eudsl-python-extras/examples/mlir_python_extras.ipynb --output=mlir_python_extras_output + cat projects/eudsl-python-extras/examples/mlir_python_extras_output.ipynb | jq '.cells[].outputs | select(length > 0) | .[0] | .text' + jupyter execute projects/eudsl-python-extras/examples/vectorization_e2e.ipynb --output=vectorization_e2e_output + cat projects/eudsl-python-extras/examples/vectorization_e2e_output.ipynb | jq '.cells[].outputs | select(length > 0) | .[0] | .text' + + # TODO(max): build wheels with nv targets + # if [ ${{ matrix.os }} == 'ubuntu' ]; then + # jupyter execute projects/eudsl-python-extras/examples/cuda_e2e.ipynb --output=cuda_e2e_output + # cat projects/eudsl-python-extras/examples/cuda_e2e_output.ipynb | jq '.cells[].outputs | select(length > 0) | .[0] | .text' + # fi + release-eudsl-python-extras: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) needs: [build-eudsl-python-extras] @@ -222,7 +268,7 @@ jobs: call-deploy-pip-page: - if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' + if: (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'workflow_dispatch' && inputs.release) needs: [release-eudsl-python-extras] diff --git 
a/projects/eudsl-python-extras/README.md b/projects/eudsl-python-extras/README.md new file mode 100644 index 00000000..3d4ff1f1 --- /dev/null +++ b/projects/eudsl-python-extras/README.md @@ -0,0 +1,158 @@ +# eudsl-python-extras + +The missing pieces (as far as boilerplate reduction goes) of the MLIR python bindings. + +* [TL;DR](#tldr) +* [5s Intro](#5s-intro) +* [Install](#install) +* [Examples/Demo](#examplesdemo) + +## TL;DR + +Full example at [examples/mwe.py](examples/mwe.py) (i.e., go there if you want to copy-paste). + +Turn this + +```python +K = 10 +memref_i64 = T.memref(K, K, T.i64) + +@func +@canonicalize(using=scf) +def memfoo(A: memref_i64, B: memref_i64, C: memref_i64): + one = constant(1) + two = constant(2) + if one > two: + three = constant(3) + else: + for i in range(0, K): + for j in range(0, K): + C[i, j] = A[i, j] * B[i, j] +``` + +into this + +```mlir +func.func @memfoo(%arg0: memref<10x10xi64>, %arg1: memref<10x10xi64>, %arg2: memref<10x10xi64>) { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %0 = arith.cmpi ugt, %c1_i32, %c2_i32 : i32 + scf.if %0 { + %c3_i32 = arith.constant 3 : i32 + } else { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %arg3 = %c0 to %c10 step %c1 { + scf.for %arg4 = %c0 to %c10 step %c1 { + %1 = memref.load %arg0[%arg3, %arg4] : memref<10x10xi64> + %2 = memref.load %arg1[%arg3, %arg4] : memref<10x10xi64> + %3 = arith.muli %1, %2 : i64 + memref.store %3, %arg2[%arg3, %arg4] : memref<10x10xi64> + } + } + } + return +} +``` + +then run it like this + +```python +module = backend.compile( + ctx.module, + kernel_name=memfoo.__name__, + pipeline=Pipeline().bufferize().lower_to_llvm(), +) + +A = np.random.randint(0, 10, (K, K)) +B = np.random.randint(0, 10, (K, K)) +C = np.zeros((K, K), dtype=int) + +backend.load(module).memfoo(A, B, C) +assert np.array_equal(A * B, C) +``` + +## 5s Intro + +This is **not a Python compiler**, but just a
(hopefully) nice way to emit MLIR using python. + +The few main features/affordances: + +1. `region_op`s (like `@func` above) + \ +   + 1. These are decorators around ops (bindings for MLIR operations) that have regions (e.g., [in_parallel](https://github.com/llvm/eudsl/blob/fa4807b17a21a4808cc0a4a8a32e2da57f7e3100/projects/eudsl-python-extras/mlir/extras/dialects/scf.py#L134)). + They turn decorated functions, by executing them "eagerly", into an instance of such an op, e.g., + ```python + @func + def foo(x: T.i32): + return + ``` + becomes `func.func @foo(%arg0: i32) { }`; if the region carrying op produces a result, the identifier for the python function (`foo`) becomes the corresponding `ir.Value` of the result (if the op doesn't produce a result then the identifier becomes the corresponding `ir.OpView`). + \ + \ + This has been upstreamed to [mlir/python/mlir/extras/meta.py](https://github.com/llvm/llvm-project/blob/24038650d9ca5d66b07d3075afdebe81012ab1f2/mlir/python/mlir/extras/meta.py#L12) + \ +   +2. `@canonicalize` (like `@canonicalize(using=scf)` above) + \ +   + 1. These are decorators that **rewrite the python AST**. They transform a select few forms (basically only `if`s) into a more "canonical" form, in order to more easily map to MLIR. If that scares you, fear not; they are not essential and all target MLIR can still be mapped to without using them (by using the slightly more verbose `region_op`). + \ + \ + See [mlir.extras.ast.canonicalize](https://github.com/llvm/eudsl/blob/f0914c3b3c0e3ca774575aa6a0fba73e1ebb631f/projects/eudsl-python-extras/mlir/extras/ast/canonicalize.py) for details. + \ +   +3. `mlir/extras.types` (like `T.memref(K, K, T.i64)` above) + \ +   + 1. These are just convenient wrappers around upstream type constructors. Note, because MLIR types are uniqued to a `ir.Context`, these are all actually functions that return the type. 
+ \ + \ + These have been upstreamed to [mlir/python/mlir/extras/types.py](https://github.com/llvm/llvm-project/blob/52b18b4e82d412a7d755e89591c6ebcc41c257a1/mlir/python/mlir/extras/types.py) + \ +   +4. `Pipeline()` + \ +   + 1. This is just a (generated) wrapper around available **upstream** passes; it can be used to build pass pipelines (by `str(Pipeline())`). It is mainly convenient with IDEs/editors that will tab-complete the available methods on the `Pipeline` class (which correspond to passes). Note, if your host bindings don't register some upstream passes, then this will generate "illegal" pass pipelines. + \ + \ + See [utils/generate_pass_pipeline.py](https://github.com/llvm/eudsl/blob/f0914c3b3c0e3ca774575aa6a0fba73e1ebb631f/projects/eudsl-python-extras/utils/generate_pass_pipeline.py) for details on generation + [mlir.extras.runtime.passes](https://github.com/llvm/eudsl/blob/4f599951786aedad96e5943993763dc9c5bfb8cd/projects/eudsl-python-extras/mlir/extras/runtime/passes.py) for the passes themselves. + \ +   + + + +Note, also, there are no docs (because ain't no one got time for that) but that shouldn't be a problem because the package is designed such that you can use/reuse only the pieces/parts you want/understand. +But, open an issue if something isn't clear. + + +## Install + +If you want to just get started/play around: + +```shell +$ pip install eudsl-python-extras -f https://llvm.github.io/eudsl +``` + +Alternatively, this [colab notebook](https://drive.google.com/file/d/1NAtf2Yxj_VVnzwn8u_kxtajfVzgbuWhi/view?usp=sharing) (which is the same as [examples/mlir_python_extras.ipynb](examples/mlir_python_extras.ipynb)) has an MWE if you don't even want to install anything. + +In reality, this package is meant to work in concert with "host bindings" (some distribution of the actual MLIR Python bindings). +Practically speaking that means you need to have *some* package installed that includes mlir python bindings.
+ +So that means the install command above should be amended to + +```shell +$ EUDSL_PYTHON_EXTRAS_HOST_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> \ + pip install eudsl-python-extras -f https://llvm.github.io/eudsl +``` + +where `YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the package prefix for your chosen host bindings. +**When in doubt about this prefix**, it is everything up until `ir` when you import your bindings, e.g., in `import torch_mlir.ir`, `torch_mlir` is the `YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX` for the torch-mlir bindings. + +## Examples/Demo + +Check [examples](examples) and [tests](tests) for a plethora of example code. \ No newline at end of file diff --git a/projects/eudsl-python-extras/examples/chip.py b/projects/eudsl-python-extras/examples/chip.py new file mode 100644 index 00000000..8c89737c --- /dev/null +++ b/projects/eudsl-python-extras/examples/chip.py @@ -0,0 +1,1156 @@ +# mypy: ignore-errors +# -*- coding: utf-8 -*- +# +# TARGET arch is: ['-D__HIP_PLATFORM_AMD__', '-I/opt/rocm/include', '-x', 'c++'] +# WORD_SIZE is: 8 +# POINTER_SIZE is: 8 +# LONGDOUBLE_SIZE is: 16 +# +import ctypes + + +class AsDictMixin: + + @classmethod + def as_dict(cls, self): + result = {} + if not isinstance(self, AsDictMixin): + # not a structure, assume it's already a python object + return self + if not hasattr(cls, "_fields_"): + return result + for field_tuple in cls._fields_: # noqa + field = field_tuple[0] + if field.startswith("PADDING_"): + continue + value = getattr(self, field) + type_ = type(value) + if hasattr(value, "_length_") and hasattr(value, "_type_"): + # array + if not hasattr(type_, "as_dict"): + value = [v for v in value] + else: + type_ = type_._type_ + value = [type_.as_dict(v) for v in value] + elif hasattr(value, "contents") and hasattr(value, "_type_"): + # pointer + try: + if not hasattr(type_, "as_dict"): + value = value.contents + else: + type_ = type_._type_ + value = type_.as_dict(value.contents) + except ValueError: + # nullptr + value = None + elif
isinstance(value, AsDictMixin): + # other structure + value = type_.as_dict(value) + result[field] = value + return result + + +class Structure(ctypes.Structure, AsDictMixin): + + def __init__(self, *args, **kwds): + # We don't want to use positional arguments fill PADDING_* fields + + args = dict(zip(self.__class__._field_names_(), args)) + args.update(kwds) + super(Structure, self).__init__(**args) + + @classmethod + def _field_names_(cls): + if hasattr(cls, "_fields_"): + return (f[0] for f in cls._fields_ if not f[0].startswith("PADDING")) + else: + return () + + @classmethod + def get_type(cls, field): + for f in cls._fields_: + if f[0] == field: + return f[1] + return None + + @classmethod + def bind(cls, bound_fields): + fields = {} + for name, type_ in cls._fields_: + if hasattr(type_, "restype"): + if name in bound_fields: + if bound_fields[name] is None: + fields[name] = type_() + else: + # use a closure to capture the callback from the loop scope + fields[name] = type_( + (lambda callback: lambda *args: callback(*args))( + bound_fields[name] + ) + ) + del bound_fields[name] + else: + # default callback implementation (does nothing) + try: + default_ = type_(0).restype().value + except TypeError: + default_ = None + fields[name] = type_( + (lambda default_: lambda *args: default_)(default_) + ) + else: + # not a callback function, use default initialization + if name in bound_fields: + fields[name] = bound_fields[name] + del bound_fields[name] + else: + fields[name] = type_() + if len(bound_fields) != 0: + raise ValueError( + "Cannot bind the following unknown callback(s) {}.{}".format( + cls.__name__, bound_fields.keys() + ) + ) + return cls(**fields) + + +class Union(ctypes.Union, AsDictMixin): + pass + + +c_int128 = ctypes.c_ubyte * 16 +c_uint128 = c_int128 +void = None +if ctypes.sizeof(ctypes.c_longdouble) == 16: + c_long_double_t = ctypes.c_longdouble +else: + c_long_double_t = ctypes.c_ubyte * 16 + + +class FunctionFactoryStub: + + def 
__getattr__(self, _): + return ctypes.CFUNCTYPE(lambda y: y) + + +_libraries = {} + + +def string_cast(char_pointer, encoding="utf-8", errors="strict"): + value = ctypes.cast(char_pointer, ctypes.c_char_p).value + if value is not None and encoding is not None: + value = value.decode(encoding, errors=errors) + return value + + +def char_pointer_cast(string, encoding="utf-8"): + if encoding is not None: + try: + string = string.encode(encoding) + except AttributeError: + # In Python3, bytes has no encode attribute + pass + string = ctypes.c_char_p(string) + return ctypes.cast(string, ctypes.POINTER(ctypes.c_char)) + + +_libraries["libamdhip64.so"] = ctypes.cdll.LoadLibrary("libamdhip64.so") + +c__Ea_HIP_SUCCESS__enumvalues = { + 0: "HIP_SUCCESS", + 1: "HIP_ERROR_INVALID_VALUE", + 2: "HIP_ERROR_NOT_INITIALIZED", + 3: "HIP_ERROR_LAUNCH_OUT_OF_RESOURCES", +} +HIP_SUCCESS = 0 +HIP_ERROR_INVALID_VALUE = 1 +HIP_ERROR_NOT_INITIALIZED = 2 +HIP_ERROR_LAUNCH_OUT_OF_RESOURCES = 3 +c__Ea_HIP_SUCCESS = ctypes.c_uint32 # enum + + +class struct_c__SA_hipDeviceArch_t(Structure): + pass + + +struct_c__SA_hipDeviceArch_t._pack_ = 1 # source:False +struct_c__SA_hipDeviceArch_t._fields_ = [ + ("hasGlobalInt32Atomics", ctypes.c_uint32, 1), + ("hasGlobalFloatAtomicExch", ctypes.c_uint32, 1), + ("hasSharedInt32Atomics", ctypes.c_uint32, 1), + ("hasSharedFloatAtomicExch", ctypes.c_uint32, 1), + ("hasFloatAtomicAdd", ctypes.c_uint32, 1), + ("hasGlobalInt64Atomics", ctypes.c_uint32, 1), + ("hasSharedInt64Atomics", ctypes.c_uint32, 1), + ("hasDoubles", ctypes.c_uint32, 1), + ("hasWarpVote", ctypes.c_uint32, 1), + ("hasWarpBallot", ctypes.c_uint32, 1), + ("hasWarpShuffle", ctypes.c_uint32, 1), + ("hasFunnelShift", ctypes.c_uint32, 1), + ("hasThreadFenceSystem", ctypes.c_uint32, 1), + ("hasSyncThreadsExt", ctypes.c_uint32, 1), + ("hasSurfaceFuncs", ctypes.c_uint32, 1), + ("has3dGrid", ctypes.c_uint32, 1), + ("hasDynamicParallelism", ctypes.c_uint32, 1), + ("PADDING_0", ctypes.c_uint16, 15), +] + 
+hipDeviceArch_t = struct_c__SA_hipDeviceArch_t + + +class struct_hipUUID_t(Structure): + pass + + +struct_hipUUID_t._pack_ = 1 # source:False +struct_hipUUID_t._fields_ = [ + ("bytes", ctypes.c_char * 16), +] + +hipUUID = struct_hipUUID_t + + +class struct_hipDeviceProp_tR0600(Structure): + pass + + +struct_hipDeviceProp_tR0600._pack_ = 1 # source:False +struct_hipDeviceProp_tR0600._fields_ = [ + ("name", ctypes.c_char * 256), + ("uuid", hipUUID), + ("luid", ctypes.c_char * 8), + ("luidDeviceNodeMask", ctypes.c_uint32), + ("PADDING_0", ctypes.c_ubyte * 4), + ("totalGlobalMem", ctypes.c_uint64), + ("sharedMemPerBlock", ctypes.c_uint64), + ("regsPerBlock", ctypes.c_int32), + ("warpSize", ctypes.c_int32), + ("memPitch", ctypes.c_uint64), + ("maxThreadsPerBlock", ctypes.c_int32), + ("maxThreadsDim", ctypes.c_int32 * 3), + ("maxGridSize", ctypes.c_int32 * 3), + ("clockRate", ctypes.c_int32), + ("totalConstMem", ctypes.c_uint64), + ("major", ctypes.c_int32), + ("minor", ctypes.c_int32), + ("textureAlignment", ctypes.c_uint64), + ("texturePitchAlignment", ctypes.c_uint64), + ("deviceOverlap", ctypes.c_int32), + ("multiProcessorCount", ctypes.c_int32), + ("kernelExecTimeoutEnabled", ctypes.c_int32), + ("integrated", ctypes.c_int32), + ("canMapHostMemory", ctypes.c_int32), + ("computeMode", ctypes.c_int32), + ("maxTexture1D", ctypes.c_int32), + ("maxTexture1DMipmap", ctypes.c_int32), + ("maxTexture1DLinear", ctypes.c_int32), + ("maxTexture2D", ctypes.c_int32 * 2), + ("maxTexture2DMipmap", ctypes.c_int32 * 2), + ("maxTexture2DLinear", ctypes.c_int32 * 3), + ("maxTexture2DGather", ctypes.c_int32 * 2), + ("maxTexture3D", ctypes.c_int32 * 3), + ("maxTexture3DAlt", ctypes.c_int32 * 3), + ("maxTextureCubemap", ctypes.c_int32), + ("maxTexture1DLayered", ctypes.c_int32 * 2), + ("maxTexture2DLayered", ctypes.c_int32 * 3), + ("maxTextureCubemapLayered", ctypes.c_int32 * 2), + ("maxSurface1D", ctypes.c_int32), + ("maxSurface2D", ctypes.c_int32 * 2), + ("maxSurface3D", ctypes.c_int32 
* 3), + ("maxSurface1DLayered", ctypes.c_int32 * 2), + ("maxSurface2DLayered", ctypes.c_int32 * 3), + ("maxSurfaceCubemap", ctypes.c_int32), + ("maxSurfaceCubemapLayered", ctypes.c_int32 * 2), + ("surfaceAlignment", ctypes.c_uint64), + ("concurrentKernels", ctypes.c_int32), + ("ECCEnabled", ctypes.c_int32), + ("pciBusID", ctypes.c_int32), + ("pciDeviceID", ctypes.c_int32), + ("pciDomainID", ctypes.c_int32), + ("tccDriver", ctypes.c_int32), + ("asyncEngineCount", ctypes.c_int32), + ("unifiedAddressing", ctypes.c_int32), + ("memoryClockRate", ctypes.c_int32), + ("memoryBusWidth", ctypes.c_int32), + ("l2CacheSize", ctypes.c_int32), + ("persistingL2CacheMaxSize", ctypes.c_int32), + ("maxThreadsPerMultiProcessor", ctypes.c_int32), + ("streamPrioritiesSupported", ctypes.c_int32), + ("globalL1CacheSupported", ctypes.c_int32), + ("localL1CacheSupported", ctypes.c_int32), + ("sharedMemPerMultiprocessor", ctypes.c_uint64), + ("regsPerMultiprocessor", ctypes.c_int32), + ("managedMemory", ctypes.c_int32), + ("isMultiGpuBoard", ctypes.c_int32), + ("multiGpuBoardGroupID", ctypes.c_int32), + ("hostNativeAtomicSupported", ctypes.c_int32), + ("singleToDoublePrecisionPerfRatio", ctypes.c_int32), + ("pageableMemoryAccess", ctypes.c_int32), + ("concurrentManagedAccess", ctypes.c_int32), + ("computePreemptionSupported", ctypes.c_int32), + ("canUseHostPointerForRegisteredMem", ctypes.c_int32), + ("cooperativeLaunch", ctypes.c_int32), + ("cooperativeMultiDeviceLaunch", ctypes.c_int32), + ("sharedMemPerBlockOptin", ctypes.c_uint64), + ("pageableMemoryAccessUsesHostPageTables", ctypes.c_int32), + ("directManagedMemAccessFromHost", ctypes.c_int32), + ("maxBlocksPerMultiProcessor", ctypes.c_int32), + ("accessPolicyMaxWindowSize", ctypes.c_int32), + ("reservedSharedMemPerBlock", ctypes.c_uint64), + ("hostRegisterSupported", ctypes.c_int32), + ("sparseHipArraySupported", ctypes.c_int32), + ("hostRegisterReadOnlySupported", ctypes.c_int32), + ("timelineSemaphoreInteropSupported", 
ctypes.c_int32), + ("memoryPoolsSupported", ctypes.c_int32), + ("gpuDirectRDMASupported", ctypes.c_int32), + ("gpuDirectRDMAFlushWritesOptions", ctypes.c_uint32), + ("gpuDirectRDMAWritesOrdering", ctypes.c_int32), + ("memoryPoolSupportedHandleTypes", ctypes.c_uint32), + ("deferredMappingHipArraySupported", ctypes.c_int32), + ("ipcEventSupported", ctypes.c_int32), + ("clusterLaunch", ctypes.c_int32), + ("unifiedFunctionPointers", ctypes.c_int32), + ("reserved", ctypes.c_int32 * 63), + ("hipReserved", ctypes.c_int32 * 32), + ("gcnArchName", ctypes.c_char * 256), + ("maxSharedMemoryPerMultiProcessor", ctypes.c_uint64), + ("clockInstructionRate", ctypes.c_int32), + ("arch", hipDeviceArch_t), + ("hdpMemFlushCntl", ctypes.POINTER(ctypes.c_uint32)), + ("hdpRegFlushCntl", ctypes.POINTER(ctypes.c_uint32)), + ("cooperativeMultiDeviceUnmatchedFunc", ctypes.c_int32), + ("cooperativeMultiDeviceUnmatchedGridDim", ctypes.c_int32), + ("cooperativeMultiDeviceUnmatchedBlockDim", ctypes.c_int32), + ("cooperativeMultiDeviceUnmatchedSharedMem", ctypes.c_int32), + ("isLargeBar", ctypes.c_int32), + ("asicRevision", ctypes.c_int32), +] + +hipDeviceProp_tR0600 = struct_hipDeviceProp_tR0600 + +hipMemoryType__enumvalues = { + 0: "hipMemoryTypeUnregistered", + 1: "hipMemoryTypeHost", + 2: "hipMemoryTypeDevice", + 3: "hipMemoryTypeManaged", + 10: "hipMemoryTypeArray", + 11: "hipMemoryTypeUnified", +} +hipMemoryTypeUnregistered = 0 +hipMemoryTypeHost = 1 +hipMemoryTypeDevice = 2 +hipMemoryTypeManaged = 3 +hipMemoryTypeArray = 10 +hipMemoryTypeUnified = 11 +hipMemoryType = ctypes.c_uint32 # enum + + +class struct_hipPointerAttribute_t(Structure): + pass + + +struct_hipPointerAttribute_t._pack_ = 1 # source:False +struct_hipPointerAttribute_t._fields_ = [ + ("type", hipMemoryType), + ("device", ctypes.c_int32), + ("devicePointer", ctypes.POINTER(None)), + ("hostPointer", ctypes.POINTER(None)), + ("isManaged", ctypes.c_int32), + ("allocationFlags", ctypes.c_uint32), +] + +hipPointerAttribute_t = 
struct_hipPointerAttribute_t + +hipError_t__enumvalues = { + 0: "hipSuccess", + 1: "hipErrorInvalidValue", + 2: "hipErrorOutOfMemory", + 2: "hipErrorMemoryAllocation", + 3: "hipErrorNotInitialized", + 3: "hipErrorInitializationError", + 4: "hipErrorDeinitialized", + 5: "hipErrorProfilerDisabled", + 6: "hipErrorProfilerNotInitialized", + 7: "hipErrorProfilerAlreadyStarted", + 8: "hipErrorProfilerAlreadyStopped", + 9: "hipErrorInvalidConfiguration", + 12: "hipErrorInvalidPitchValue", + 13: "hipErrorInvalidSymbol", + 17: "hipErrorInvalidDevicePointer", + 21: "hipErrorInvalidMemcpyDirection", + 35: "hipErrorInsufficientDriver", + 52: "hipErrorMissingConfiguration", + 53: "hipErrorPriorLaunchFailure", + 98: "hipErrorInvalidDeviceFunction", + 100: "hipErrorNoDevice", + 101: "hipErrorInvalidDevice", + 200: "hipErrorInvalidImage", + 201: "hipErrorInvalidContext", + 202: "hipErrorContextAlreadyCurrent", + 205: "hipErrorMapFailed", + 205: "hipErrorMapBufferObjectFailed", + 206: "hipErrorUnmapFailed", + 207: "hipErrorArrayIsMapped", + 208: "hipErrorAlreadyMapped", + 209: "hipErrorNoBinaryForGpu", + 210: "hipErrorAlreadyAcquired", + 211: "hipErrorNotMapped", + 212: "hipErrorNotMappedAsArray", + 213: "hipErrorNotMappedAsPointer", + 214: "hipErrorECCNotCorrectable", + 215: "hipErrorUnsupportedLimit", + 216: "hipErrorContextAlreadyInUse", + 217: "hipErrorPeerAccessUnsupported", + 218: "hipErrorInvalidKernelFile", + 219: "hipErrorInvalidGraphicsContext", + 300: "hipErrorInvalidSource", + 301: "hipErrorFileNotFound", + 302: "hipErrorSharedObjectSymbolNotFound", + 303: "hipErrorSharedObjectInitFailed", + 304: "hipErrorOperatingSystem", + 400: "hipErrorInvalidHandle", + 400: "hipErrorInvalidResourceHandle", + 401: "hipErrorIllegalState", + 500: "hipErrorNotFound", + 600: "hipErrorNotReady", + 700: "hipErrorIllegalAddress", + 701: "hipErrorLaunchOutOfResources", + 702: "hipErrorLaunchTimeOut", + 704: "hipErrorPeerAccessAlreadyEnabled", + 705: "hipErrorPeerAccessNotEnabled", + 708: 
"hipErrorSetOnActiveProcess", + 709: "hipErrorContextIsDestroyed", + 710: "hipErrorAssert", + 712: "hipErrorHostMemoryAlreadyRegistered", + 713: "hipErrorHostMemoryNotRegistered", + 719: "hipErrorLaunchFailure", + 720: "hipErrorCooperativeLaunchTooLarge", + 801: "hipErrorNotSupported", + 900: "hipErrorStreamCaptureUnsupported", + 901: "hipErrorStreamCaptureInvalidated", + 902: "hipErrorStreamCaptureMerge", + 903: "hipErrorStreamCaptureUnmatched", + 904: "hipErrorStreamCaptureUnjoined", + 905: "hipErrorStreamCaptureIsolation", + 906: "hipErrorStreamCaptureImplicit", + 907: "hipErrorCapturedEvent", + 908: "hipErrorStreamCaptureWrongThread", + 910: "hipErrorGraphExecUpdateFailure", + 999: "hipErrorUnknown", + 1052: "hipErrorRuntimeMemory", + 1053: "hipErrorRuntimeOther", + 1054: "hipErrorTbd", +} +hipSuccess = 0 +hipErrorInvalidValue = 1 +hipErrorOutOfMemory = 2 +hipErrorMemoryAllocation = 2 +hipErrorNotInitialized = 3 +hipErrorInitializationError = 3 +hipErrorDeinitialized = 4 +hipErrorProfilerDisabled = 5 +hipErrorProfilerNotInitialized = 6 +hipErrorProfilerAlreadyStarted = 7 +hipErrorProfilerAlreadyStopped = 8 +hipErrorInvalidConfiguration = 9 +hipErrorInvalidPitchValue = 12 +hipErrorInvalidSymbol = 13 +hipErrorInvalidDevicePointer = 17 +hipErrorInvalidMemcpyDirection = 21 +hipErrorInsufficientDriver = 35 +hipErrorMissingConfiguration = 52 +hipErrorPriorLaunchFailure = 53 +hipErrorInvalidDeviceFunction = 98 +hipErrorNoDevice = 100 +hipErrorInvalidDevice = 101 +hipErrorInvalidImage = 200 +hipErrorInvalidContext = 201 +hipErrorContextAlreadyCurrent = 202 +hipErrorMapFailed = 205 +hipErrorMapBufferObjectFailed = 205 +hipErrorUnmapFailed = 206 +hipErrorArrayIsMapped = 207 +hipErrorAlreadyMapped = 208 +hipErrorNoBinaryForGpu = 209 +hipErrorAlreadyAcquired = 210 +hipErrorNotMapped = 211 +hipErrorNotMappedAsArray = 212 +hipErrorNotMappedAsPointer = 213 +hipErrorECCNotCorrectable = 214 +hipErrorUnsupportedLimit = 215 +hipErrorContextAlreadyInUse = 216 
+hipErrorPeerAccessUnsupported = 217 +hipErrorInvalidKernelFile = 218 +hipErrorInvalidGraphicsContext = 219 +hipErrorInvalidSource = 300 +hipErrorFileNotFound = 301 +hipErrorSharedObjectSymbolNotFound = 302 +hipErrorSharedObjectInitFailed = 303 +hipErrorOperatingSystem = 304 +hipErrorInvalidHandle = 400 +hipErrorInvalidResourceHandle = 400 +hipErrorIllegalState = 401 +hipErrorNotFound = 500 +hipErrorNotReady = 600 +hipErrorIllegalAddress = 700 +hipErrorLaunchOutOfResources = 701 +hipErrorLaunchTimeOut = 702 +hipErrorPeerAccessAlreadyEnabled = 704 +hipErrorPeerAccessNotEnabled = 705 +hipErrorSetOnActiveProcess = 708 +hipErrorContextIsDestroyed = 709 +hipErrorAssert = 710 +hipErrorHostMemoryAlreadyRegistered = 712 +hipErrorHostMemoryNotRegistered = 713 +hipErrorLaunchFailure = 719 +hipErrorCooperativeLaunchTooLarge = 720 +hipErrorNotSupported = 801 +hipErrorStreamCaptureUnsupported = 900 +hipErrorStreamCaptureInvalidated = 901 +hipErrorStreamCaptureMerge = 902 +hipErrorStreamCaptureUnmatched = 903 +hipErrorStreamCaptureUnjoined = 904 +hipErrorStreamCaptureIsolation = 905 +hipErrorStreamCaptureImplicit = 906 +hipErrorCapturedEvent = 907 +hipErrorStreamCaptureWrongThread = 908 +hipErrorGraphExecUpdateFailure = 910 +hipErrorUnknown = 999 +hipErrorRuntimeMemory = 1052 +hipErrorRuntimeOther = 1053 +hipErrorTbd = 1054 +hipError_t = ctypes.c_uint32 # enum + +hipDeviceAttribute_t__enumvalues = { + 0: "hipDeviceAttributeCudaCompatibleBegin", + 0: "hipDeviceAttributeEccEnabled", + 1: "hipDeviceAttributeAccessPolicyMaxWindowSize", + 2: "hipDeviceAttributeAsyncEngineCount", + 3: "hipDeviceAttributeCanMapHostMemory", + 4: "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + 5: "hipDeviceAttributeClockRate", + 6: "hipDeviceAttributeComputeMode", + 7: "hipDeviceAttributeComputePreemptionSupported", + 8: "hipDeviceAttributeConcurrentKernels", + 9: "hipDeviceAttributeConcurrentManagedAccess", + 10: "hipDeviceAttributeCooperativeLaunch", + 11: 
"hipDeviceAttributeCooperativeMultiDeviceLaunch", + 12: "hipDeviceAttributeDeviceOverlap", + 13: "hipDeviceAttributeDirectManagedMemAccessFromHost", + 14: "hipDeviceAttributeGlobalL1CacheSupported", + 15: "hipDeviceAttributeHostNativeAtomicSupported", + 16: "hipDeviceAttributeIntegrated", + 17: "hipDeviceAttributeIsMultiGpuBoard", + 18: "hipDeviceAttributeKernelExecTimeout", + 19: "hipDeviceAttributeL2CacheSize", + 20: "hipDeviceAttributeLocalL1CacheSupported", + 21: "hipDeviceAttributeLuid", + 22: "hipDeviceAttributeLuidDeviceNodeMask", + 23: "hipDeviceAttributeComputeCapabilityMajor", + 24: "hipDeviceAttributeManagedMemory", + 25: "hipDeviceAttributeMaxBlocksPerMultiProcessor", + 26: "hipDeviceAttributeMaxBlockDimX", + 27: "hipDeviceAttributeMaxBlockDimY", + 28: "hipDeviceAttributeMaxBlockDimZ", + 29: "hipDeviceAttributeMaxGridDimX", + 30: "hipDeviceAttributeMaxGridDimY", + 31: "hipDeviceAttributeMaxGridDimZ", + 32: "hipDeviceAttributeMaxSurface1D", + 33: "hipDeviceAttributeMaxSurface1DLayered", + 34: "hipDeviceAttributeMaxSurface2D", + 35: "hipDeviceAttributeMaxSurface2DLayered", + 36: "hipDeviceAttributeMaxSurface3D", + 37: "hipDeviceAttributeMaxSurfaceCubemap", + 38: "hipDeviceAttributeMaxSurfaceCubemapLayered", + 39: "hipDeviceAttributeMaxTexture1DWidth", + 40: "hipDeviceAttributeMaxTexture1DLayered", + 41: "hipDeviceAttributeMaxTexture1DLinear", + 42: "hipDeviceAttributeMaxTexture1DMipmap", + 43: "hipDeviceAttributeMaxTexture2DWidth", + 44: "hipDeviceAttributeMaxTexture2DHeight", + 45: "hipDeviceAttributeMaxTexture2DGather", + 46: "hipDeviceAttributeMaxTexture2DLayered", + 47: "hipDeviceAttributeMaxTexture2DLinear", + 48: "hipDeviceAttributeMaxTexture2DMipmap", + 49: "hipDeviceAttributeMaxTexture3DWidth", + 50: "hipDeviceAttributeMaxTexture3DHeight", + 51: "hipDeviceAttributeMaxTexture3DDepth", + 52: "hipDeviceAttributeMaxTexture3DAlt", + 53: "hipDeviceAttributeMaxTextureCubemap", + 54: "hipDeviceAttributeMaxTextureCubemapLayered", + 55: 
"hipDeviceAttributeMaxThreadsDim", + 56: "hipDeviceAttributeMaxThreadsPerBlock", + 57: "hipDeviceAttributeMaxThreadsPerMultiProcessor", + 58: "hipDeviceAttributeMaxPitch", + 59: "hipDeviceAttributeMemoryBusWidth", + 60: "hipDeviceAttributeMemoryClockRate", + 61: "hipDeviceAttributeComputeCapabilityMinor", + 62: "hipDeviceAttributeMultiGpuBoardGroupID", + 63: "hipDeviceAttributeMultiprocessorCount", + 64: "hipDeviceAttributeUnused1", + 65: "hipDeviceAttributePageableMemoryAccess", + 66: "hipDeviceAttributePageableMemoryAccessUsesHostPageTables", + 67: "hipDeviceAttributePciBusId", + 68: "hipDeviceAttributePciDeviceId", + 69: "hipDeviceAttributePciDomainID", + 70: "hipDeviceAttributePersistingL2CacheMaxSize", + 71: "hipDeviceAttributeMaxRegistersPerBlock", + 72: "hipDeviceAttributeMaxRegistersPerMultiprocessor", + 73: "hipDeviceAttributeReservedSharedMemPerBlock", + 74: "hipDeviceAttributeMaxSharedMemoryPerBlock", + 75: "hipDeviceAttributeSharedMemPerBlockOptin", + 76: "hipDeviceAttributeSharedMemPerMultiprocessor", + 77: "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + 78: "hipDeviceAttributeStreamPrioritiesSupported", + 79: "hipDeviceAttributeSurfaceAlignment", + 80: "hipDeviceAttributeTccDriver", + 81: "hipDeviceAttributeTextureAlignment", + 82: "hipDeviceAttributeTexturePitchAlignment", + 83: "hipDeviceAttributeTotalConstantMemory", + 84: "hipDeviceAttributeTotalGlobalMem", + 85: "hipDeviceAttributeUnifiedAddressing", + 86: "hipDeviceAttributeUnused2", + 87: "hipDeviceAttributeWarpSize", + 88: "hipDeviceAttributeMemoryPoolsSupported", + 89: "hipDeviceAttributeVirtualMemoryManagementSupported", + 90: "hipDeviceAttributeHostRegisterSupported", + 9999: "hipDeviceAttributeCudaCompatibleEnd", + 10000: "hipDeviceAttributeAmdSpecificBegin", + 10000: "hipDeviceAttributeClockInstructionRate", + 10001: "hipDeviceAttributeUnused3", + 10002: "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + 10003: "hipDeviceAttributeUnused4", + 10004: 
"hipDeviceAttributeUnused5", + 10005: "hipDeviceAttributeHdpMemFlushCntl", + 10006: "hipDeviceAttributeHdpRegFlushCntl", + 10007: "hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc", + 10008: "hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim", + 10009: "hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim", + 10010: "hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem", + 10011: "hipDeviceAttributeIsLargeBar", + 10012: "hipDeviceAttributeAsicRevision", + 10013: "hipDeviceAttributeCanUseStreamWaitValue", + 10014: "hipDeviceAttributeImageSupport", + 10015: "hipDeviceAttributePhysicalMultiProcessorCount", + 10016: "hipDeviceAttributeFineGrainSupport", + 10017: "hipDeviceAttributeWallClockRate", + 19999: "hipDeviceAttributeAmdSpecificEnd", + 20000: "hipDeviceAttributeVendorSpecificBegin", +} +hipDeviceAttributeCudaCompatibleBegin = 0 +hipDeviceAttributeEccEnabled = 0 +hipDeviceAttributeAccessPolicyMaxWindowSize = 1 +hipDeviceAttributeAsyncEngineCount = 2 +hipDeviceAttributeCanMapHostMemory = 3 +hipDeviceAttributeCanUseHostPointerForRegisteredMem = 4 +hipDeviceAttributeClockRate = 5 +hipDeviceAttributeComputeMode = 6 +hipDeviceAttributeComputePreemptionSupported = 7 +hipDeviceAttributeConcurrentKernels = 8 +hipDeviceAttributeConcurrentManagedAccess = 9 +hipDeviceAttributeCooperativeLaunch = 10 +hipDeviceAttributeCooperativeMultiDeviceLaunch = 11 +hipDeviceAttributeDeviceOverlap = 12 +hipDeviceAttributeDirectManagedMemAccessFromHost = 13 +hipDeviceAttributeGlobalL1CacheSupported = 14 +hipDeviceAttributeHostNativeAtomicSupported = 15 +hipDeviceAttributeIntegrated = 16 +hipDeviceAttributeIsMultiGpuBoard = 17 +hipDeviceAttributeKernelExecTimeout = 18 +hipDeviceAttributeL2CacheSize = 19 +hipDeviceAttributeLocalL1CacheSupported = 20 +hipDeviceAttributeLuid = 21 +hipDeviceAttributeLuidDeviceNodeMask = 22 +hipDeviceAttributeComputeCapabilityMajor = 23 +hipDeviceAttributeManagedMemory = 24 +hipDeviceAttributeMaxBlocksPerMultiProcessor = 25 
+hipDeviceAttributeMaxBlockDimX = 26 +hipDeviceAttributeMaxBlockDimY = 27 +hipDeviceAttributeMaxBlockDimZ = 28 +hipDeviceAttributeMaxGridDimX = 29 +hipDeviceAttributeMaxGridDimY = 30 +hipDeviceAttributeMaxGridDimZ = 31 +hipDeviceAttributeMaxSurface1D = 32 +hipDeviceAttributeMaxSurface1DLayered = 33 +hipDeviceAttributeMaxSurface2D = 34 +hipDeviceAttributeMaxSurface2DLayered = 35 +hipDeviceAttributeMaxSurface3D = 36 +hipDeviceAttributeMaxSurfaceCubemap = 37 +hipDeviceAttributeMaxSurfaceCubemapLayered = 38 +hipDeviceAttributeMaxTexture1DWidth = 39 +hipDeviceAttributeMaxTexture1DLayered = 40 +hipDeviceAttributeMaxTexture1DLinear = 41 +hipDeviceAttributeMaxTexture1DMipmap = 42 +hipDeviceAttributeMaxTexture2DWidth = 43 +hipDeviceAttributeMaxTexture2DHeight = 44 +hipDeviceAttributeMaxTexture2DGather = 45 +hipDeviceAttributeMaxTexture2DLayered = 46 +hipDeviceAttributeMaxTexture2DLinear = 47 +hipDeviceAttributeMaxTexture2DMipmap = 48 +hipDeviceAttributeMaxTexture3DWidth = 49 +hipDeviceAttributeMaxTexture3DHeight = 50 +hipDeviceAttributeMaxTexture3DDepth = 51 +hipDeviceAttributeMaxTexture3DAlt = 52 +hipDeviceAttributeMaxTextureCubemap = 53 +hipDeviceAttributeMaxTextureCubemapLayered = 54 +hipDeviceAttributeMaxThreadsDim = 55 +hipDeviceAttributeMaxThreadsPerBlock = 56 +hipDeviceAttributeMaxThreadsPerMultiProcessor = 57 +hipDeviceAttributeMaxPitch = 58 +hipDeviceAttributeMemoryBusWidth = 59 +hipDeviceAttributeMemoryClockRate = 60 +hipDeviceAttributeComputeCapabilityMinor = 61 +hipDeviceAttributeMultiGpuBoardGroupID = 62 +hipDeviceAttributeMultiprocessorCount = 63 +hipDeviceAttributeUnused1 = 64 +hipDeviceAttributePageableMemoryAccess = 65 +hipDeviceAttributePageableMemoryAccessUsesHostPageTables = 66 +hipDeviceAttributePciBusId = 67 +hipDeviceAttributePciDeviceId = 68 +hipDeviceAttributePciDomainID = 69 +hipDeviceAttributePersistingL2CacheMaxSize = 70 +hipDeviceAttributeMaxRegistersPerBlock = 71 +hipDeviceAttributeMaxRegistersPerMultiprocessor = 72 
+hipDeviceAttributeReservedSharedMemPerBlock = 73 +hipDeviceAttributeMaxSharedMemoryPerBlock = 74 +hipDeviceAttributeSharedMemPerBlockOptin = 75 +hipDeviceAttributeSharedMemPerMultiprocessor = 76 +hipDeviceAttributeSingleToDoublePrecisionPerfRatio = 77 +hipDeviceAttributeStreamPrioritiesSupported = 78 +hipDeviceAttributeSurfaceAlignment = 79 +hipDeviceAttributeTccDriver = 80 +hipDeviceAttributeTextureAlignment = 81 +hipDeviceAttributeTexturePitchAlignment = 82 +hipDeviceAttributeTotalConstantMemory = 83 +hipDeviceAttributeTotalGlobalMem = 84 +hipDeviceAttributeUnifiedAddressing = 85 +hipDeviceAttributeUnused2 = 86 +hipDeviceAttributeWarpSize = 87 +hipDeviceAttributeMemoryPoolsSupported = 88 +hipDeviceAttributeVirtualMemoryManagementSupported = 89 +hipDeviceAttributeHostRegisterSupported = 90 +hipDeviceAttributeCudaCompatibleEnd = 9999 +hipDeviceAttributeAmdSpecificBegin = 10000 +hipDeviceAttributeClockInstructionRate = 10000 +hipDeviceAttributeUnused3 = 10001 +hipDeviceAttributeMaxSharedMemoryPerMultiprocessor = 10002 +hipDeviceAttributeUnused4 = 10003 +hipDeviceAttributeUnused5 = 10004 +hipDeviceAttributeHdpMemFlushCntl = 10005 +hipDeviceAttributeHdpRegFlushCntl = 10006 +hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc = 10007 +hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim = 10008 +hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim = 10009 +hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem = 10010 +hipDeviceAttributeIsLargeBar = 10011 +hipDeviceAttributeAsicRevision = 10012 +hipDeviceAttributeCanUseStreamWaitValue = 10013 +hipDeviceAttributeImageSupport = 10014 +hipDeviceAttributePhysicalMultiProcessorCount = 10015 +hipDeviceAttributeFineGrainSupport = 10016 +hipDeviceAttributeWallClockRate = 10017 +hipDeviceAttributeAmdSpecificEnd = 19999 +hipDeviceAttributeVendorSpecificBegin = 20000 +hipDeviceAttribute_t = ctypes.c_uint32 # enum +hipDeviceptr_t = ctypes.POINTER(None) +hipFunction_attribute__enumvalues = { + 0: 
"HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + 1: "HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES", + 2: "HIP_FUNC_ATTRIBUTE_CONST_SIZE_BYTES", + 3: "HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES", + 4: "HIP_FUNC_ATTRIBUTE_NUM_REGS", + 5: "HIP_FUNC_ATTRIBUTE_PTX_VERSION", + 6: "HIP_FUNC_ATTRIBUTE_BINARY_VERSION", + 7: "HIP_FUNC_ATTRIBUTE_CACHE_MODE_CA", + 8: "HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES", + 9: "HIP_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT", + 10: "HIP_FUNC_ATTRIBUTE_MAX", +} +HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0 +HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES = 1 +HIP_FUNC_ATTRIBUTE_CONST_SIZE_BYTES = 2 +HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES = 3 +HIP_FUNC_ATTRIBUTE_NUM_REGS = 4 +HIP_FUNC_ATTRIBUTE_PTX_VERSION = 5 +HIP_FUNC_ATTRIBUTE_BINARY_VERSION = 6 +HIP_FUNC_ATTRIBUTE_CACHE_MODE_CA = 7 +HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8 +HIP_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT = 9 +HIP_FUNC_ATTRIBUTE_MAX = 10 +hipFunction_attribute = ctypes.c_uint32 # enum + + +class struct_ihipStream_t(Structure): + pass + + +hipStream_t = ctypes.POINTER(struct_ihipStream_t) + + +class struct_ihipModule_t(Structure): + pass + + +hipModule_t = ctypes.POINTER(struct_ihipModule_t) + + +class struct_ihipModuleSymbol_t(Structure): + pass + + +hipFunction_t = ctypes.POINTER(struct_ihipModuleSymbol_t) +hipJitOption__enumvalues = { + 0: "hipJitOptionMaxRegisters", + 1: "hipJitOptionThreadsPerBlock", + 2: "hipJitOptionWallTime", + 3: "hipJitOptionInfoLogBuffer", + 4: "hipJitOptionInfoLogBufferSizeBytes", + 5: "hipJitOptionErrorLogBuffer", + 6: "hipJitOptionErrorLogBufferSizeBytes", + 7: "hipJitOptionOptimizationLevel", + 8: "hipJitOptionTargetFromContext", + 9: "hipJitOptionTarget", + 10: "hipJitOptionFallbackStrategy", + 11: "hipJitOptionGenerateDebugInfo", + 12: "hipJitOptionLogVerbose", + 13: "hipJitOptionGenerateLineInfo", + 14: "hipJitOptionCacheMode", + 15: "hipJitOptionSm3xOpt", + 16: "hipJitOptionFastCompile", + 17: "hipJitOptionNumOptions", +} 
+hipJitOptionMaxRegisters = 0 +hipJitOptionThreadsPerBlock = 1 +hipJitOptionWallTime = 2 +hipJitOptionInfoLogBuffer = 3 +hipJitOptionInfoLogBufferSizeBytes = 4 +hipJitOptionErrorLogBuffer = 5 +hipJitOptionErrorLogBufferSizeBytes = 6 +hipJitOptionOptimizationLevel = 7 +hipJitOptionTargetFromContext = 8 +hipJitOptionTarget = 9 +hipJitOptionFallbackStrategy = 10 +hipJitOptionGenerateDebugInfo = 11 +hipJitOptionLogVerbose = 12 +hipJitOptionGenerateLineInfo = 13 +hipJitOptionCacheMode = 14 +hipJitOptionSm3xOpt = 15 +hipJitOptionFastCompile = 16 +hipJitOptionNumOptions = 17 +hipJitOption = ctypes.c_uint32 # enum + +hipGetDevicePropertiesR0600 = _libraries["libamdhip64.so"].hipGetDevicePropertiesR0600 +hipGetDevicePropertiesR0600.restype = hipError_t +hipGetDevicePropertiesR0600.argtypes = [ + ctypes.POINTER(struct_hipDeviceProp_tR0600), + ctypes.c_int32, +] + +hipPointerGetAttributes = _libraries["libamdhip64.so"].hipPointerGetAttributes +hipPointerGetAttributes.restype = hipError_t +hipPointerGetAttributes.argtypes = [ + ctypes.POINTER(struct_hipPointerAttribute_t), + ctypes.POINTER(None), +] + +hipModuleGetFunction = _libraries["libamdhip64.so"].hipModuleGetFunction +hipModuleGetFunction.restype = hipError_t +hipModuleGetFunction.argtypes = [ + ctypes.POINTER(ctypes.POINTER(struct_ihipModuleSymbol_t)), + hipModule_t, + ctypes.POINTER(ctypes.c_char), +] + +hipFuncGetAttribute = _libraries["libamdhip64.so"].hipFuncGetAttribute +hipFuncGetAttribute.restype = hipError_t +hipFuncGetAttribute.argtypes = [ + ctypes.POINTER(ctypes.c_int32), + hipFunction_attribute, + hipFunction_t, +] + +hipModuleLoadDataEx = _libraries["libamdhip64.so"].hipModuleLoadDataEx +hipModuleLoadDataEx.restype = hipError_t +hipModuleLoadDataEx.argtypes = [ + ctypes.POINTER(ctypes.POINTER(struct_ihipModule_t)), + ctypes.POINTER(None), + ctypes.c_uint32, + ctypes.POINTER(hipJitOption), + ctypes.POINTER(ctypes.POINTER(None)), +] + +hipModuleLaunchKernel = 
_libraries["libamdhip64.so"].hipModuleLaunchKernel +hipModuleLaunchKernel.restype = hipError_t +hipModuleLaunchKernel.argtypes = [ + hipFunction_t, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + hipStream_t, + ctypes.POINTER(ctypes.POINTER(None)), + ctypes.POINTER(ctypes.POINTER(None)), +] + +hipDeviceProp_t = hipDeviceProp_tR0600 +hipGetDeviceProperties = hipGetDevicePropertiesR0600 + +hipGetErrorString = _libraries["libamdhip64.so"].hipGetErrorString +hipGetErrorString.restype = ctypes.POINTER(ctypes.c_char) +hipGetErrorString.argtypes = [hipError_t] + +__all__ = [ + "HIP_ERROR_INVALID_VALUE", + "HIP_ERROR_LAUNCH_OUT_OF_RESOURCES", + "HIP_ERROR_NOT_INITIALIZED", + "HIP_FUNC_ATTRIBUTE_BINARY_VERSION", + "HIP_FUNC_ATTRIBUTE_CACHE_MODE_CA", + "HIP_FUNC_ATTRIBUTE_CONST_SIZE_BYTES", + "HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES", + "HIP_FUNC_ATTRIBUTE_MAX", + "HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES", + "HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + "HIP_FUNC_ATTRIBUTE_NUM_REGS", + "HIP_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT", + "HIP_FUNC_ATTRIBUTE_PTX_VERSION", + "HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES", + "HIP_SUCCESS", + "hipDeviceArch_t", + "hipDeviceAttributeAccessPolicyMaxWindowSize", + "hipDeviceAttributeAmdSpecificBegin", + "hipDeviceAttributeAmdSpecificEnd", + "hipDeviceAttributeAsicRevision", + "hipDeviceAttributeAsyncEngineCount", + "hipDeviceAttributeCanMapHostMemory", + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + "hipDeviceAttributeCanUseStreamWaitValue", + "hipDeviceAttributeClockInstructionRate", + "hipDeviceAttributeClockRate", + "hipDeviceAttributeComputeCapabilityMajor", + "hipDeviceAttributeComputeCapabilityMinor", + "hipDeviceAttributeComputeMode", + "hipDeviceAttributeComputePreemptionSupported", + "hipDeviceAttributeConcurrentKernels", + "hipDeviceAttributeConcurrentManagedAccess", + "hipDeviceAttributeCooperativeLaunch", + 
"hipDeviceAttributeCooperativeMultiDeviceLaunch", + "hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim", + "hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc", + "hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim", + "hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem", + "hipDeviceAttributeCudaCompatibleBegin", + "hipDeviceAttributeCudaCompatibleEnd", + "hipDeviceAttributeDeviceOverlap", + "hipDeviceAttributeDirectManagedMemAccessFromHost", + "hipDeviceAttributeEccEnabled", + "hipDeviceAttributeFineGrainSupport", + "hipDeviceAttributeGlobalL1CacheSupported", + "hipDeviceAttributeHdpMemFlushCntl", + "hipDeviceAttributeHdpRegFlushCntl", + "hipDeviceAttributeHostNativeAtomicSupported", + "hipDeviceAttributeHostRegisterSupported", + "hipDeviceAttributeImageSupport", + "hipDeviceAttributeIntegrated", + "hipDeviceAttributeIsLargeBar", + "hipDeviceAttributeIsMultiGpuBoard", + "hipDeviceAttributeKernelExecTimeout", + "hipDeviceAttributeL2CacheSize", + "hipDeviceAttributeLocalL1CacheSupported", + "hipDeviceAttributeLuid", + "hipDeviceAttributeLuidDeviceNodeMask", + "hipDeviceAttributeManagedMemory", + "hipDeviceAttributeMaxBlockDimX", + "hipDeviceAttributeMaxBlockDimY", + "hipDeviceAttributeMaxBlockDimZ", + "hipDeviceAttributeMaxBlocksPerMultiProcessor", + "hipDeviceAttributeMaxGridDimX", + "hipDeviceAttributeMaxGridDimY", + "hipDeviceAttributeMaxGridDimZ", + "hipDeviceAttributeMaxPitch", + "hipDeviceAttributeMaxRegistersPerBlock", + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + "hipDeviceAttributeMaxSharedMemoryPerBlock", + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + "hipDeviceAttributeMaxSurface1D", + "hipDeviceAttributeMaxSurface1DLayered", + "hipDeviceAttributeMaxSurface2D", + "hipDeviceAttributeMaxSurface2DLayered", + "hipDeviceAttributeMaxSurface3D", + "hipDeviceAttributeMaxSurfaceCubemap", + "hipDeviceAttributeMaxSurfaceCubemapLayered", + "hipDeviceAttributeMaxTexture1DLayered", + "hipDeviceAttributeMaxTexture1DLinear", 
+ "hipDeviceAttributeMaxTexture1DMipmap", + "hipDeviceAttributeMaxTexture1DWidth", + "hipDeviceAttributeMaxTexture2DGather", + "hipDeviceAttributeMaxTexture2DHeight", + "hipDeviceAttributeMaxTexture2DLayered", + "hipDeviceAttributeMaxTexture2DLinear", + "hipDeviceAttributeMaxTexture2DMipmap", + "hipDeviceAttributeMaxTexture2DWidth", + "hipDeviceAttributeMaxTexture3DAlt", + "hipDeviceAttributeMaxTexture3DDepth", + "hipDeviceAttributeMaxTexture3DHeight", + "hipDeviceAttributeMaxTexture3DWidth", + "hipDeviceAttributeMaxTextureCubemap", + "hipDeviceAttributeMaxTextureCubemapLayered", + "hipDeviceAttributeMaxThreadsDim", + "hipDeviceAttributeMaxThreadsPerBlock", + "hipDeviceAttributeMaxThreadsPerMultiProcessor", + "hipDeviceAttributeMemoryBusWidth", + "hipDeviceAttributeMemoryClockRate", + "hipDeviceAttributeMemoryPoolsSupported", + "hipDeviceAttributeMultiGpuBoardGroupID", + "hipDeviceAttributeMultiprocessorCount", + "hipDeviceAttributePageableMemoryAccess", + "hipDeviceAttributePageableMemoryAccessUsesHostPageTables", + "hipDeviceAttributePciBusId", + "hipDeviceAttributePciDeviceId", + "hipDeviceAttributePciDomainID", + "hipDeviceAttributePersistingL2CacheMaxSize", + "hipDeviceAttributePhysicalMultiProcessorCount", + "hipDeviceAttributeReservedSharedMemPerBlock", + "hipDeviceAttributeSharedMemPerBlockOptin", + "hipDeviceAttributeSharedMemPerMultiprocessor", + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + "hipDeviceAttributeStreamPrioritiesSupported", + "hipDeviceAttributeSurfaceAlignment", + "hipDeviceAttributeTccDriver", + "hipDeviceAttributeTextureAlignment", + "hipDeviceAttributeTexturePitchAlignment", + "hipDeviceAttributeTotalConstantMemory", + "hipDeviceAttributeTotalGlobalMem", + "hipDeviceAttributeUnifiedAddressing", + "hipDeviceAttributeUnused1", + "hipDeviceAttributeUnused2", + "hipDeviceAttributeUnused3", + "hipDeviceAttributeUnused4", + "hipDeviceAttributeUnused5", + "hipDeviceAttributeVendorSpecificBegin", + 
"hipDeviceAttributeVirtualMemoryManagementSupported", + "hipDeviceAttributeWallClockRate", + "hipDeviceAttributeWarpSize", + "hipDeviceAttribute_t", + "hipDeviceProp_tR0600", + "hipDeviceptr_t", + "hipErrorAlreadyAcquired", + "hipErrorAlreadyMapped", + "hipErrorArrayIsMapped", + "hipErrorAssert", + "hipErrorCapturedEvent", + "hipErrorContextAlreadyCurrent", + "hipErrorContextAlreadyInUse", + "hipErrorContextIsDestroyed", + "hipErrorCooperativeLaunchTooLarge", + "hipErrorDeinitialized", + "hipErrorECCNotCorrectable", + "hipErrorFileNotFound", + "hipErrorGraphExecUpdateFailure", + "hipErrorHostMemoryAlreadyRegistered", + "hipErrorHostMemoryNotRegistered", + "hipErrorIllegalAddress", + "hipErrorIllegalState", + "hipErrorInitializationError", + "hipErrorInsufficientDriver", + "hipErrorInvalidConfiguration", + "hipErrorInvalidContext", + "hipErrorInvalidDevice", + "hipErrorInvalidDeviceFunction", + "hipErrorInvalidDevicePointer", + "hipErrorInvalidGraphicsContext", + "hipErrorInvalidHandle", + "hipErrorInvalidImage", + "hipErrorInvalidKernelFile", + "hipErrorInvalidMemcpyDirection", + "hipErrorInvalidPitchValue", + "hipErrorInvalidResourceHandle", + "hipErrorInvalidSource", + "hipErrorInvalidSymbol", + "hipErrorInvalidValue", + "hipErrorLaunchFailure", + "hipErrorLaunchOutOfResources", + "hipErrorLaunchTimeOut", + "hipErrorMapBufferObjectFailed", + "hipErrorMapFailed", + "hipErrorMemoryAllocation", + "hipErrorMissingConfiguration", + "hipErrorNoBinaryForGpu", + "hipErrorNoDevice", + "hipErrorNotFound", + "hipErrorNotInitialized", + "hipErrorNotMapped", + "hipErrorNotMappedAsArray", + "hipErrorNotMappedAsPointer", + "hipErrorNotReady", + "hipErrorNotSupported", + "hipErrorOperatingSystem", + "hipErrorOutOfMemory", + "hipErrorPeerAccessAlreadyEnabled", + "hipErrorPeerAccessNotEnabled", + "hipErrorPeerAccessUnsupported", + "hipErrorPriorLaunchFailure", + "hipErrorProfilerAlreadyStarted", + "hipErrorProfilerAlreadyStopped", + "hipErrorProfilerDisabled", + 
"hipErrorProfilerNotInitialized", + "hipErrorRuntimeMemory", + "hipErrorRuntimeOther", + "hipErrorSetOnActiveProcess", + "hipErrorSharedObjectInitFailed", + "hipErrorSharedObjectSymbolNotFound", + "hipErrorStreamCaptureImplicit", + "hipErrorStreamCaptureInvalidated", + "hipErrorStreamCaptureIsolation", + "hipErrorStreamCaptureMerge", + "hipErrorStreamCaptureUnjoined", + "hipErrorStreamCaptureUnmatched", + "hipErrorStreamCaptureUnsupported", + "hipErrorStreamCaptureWrongThread", + "hipErrorTbd", + "hipErrorUnknown", + "hipErrorUnmapFailed", + "hipErrorUnsupportedLimit", + "hipError_t", + "hipFuncGetAttribute", + "hipFunction_attribute", + "hipFunction_t", + "hipGetErrorString", + "hipJitOption", + "hipJitOptionCacheMode", + "hipJitOptionErrorLogBuffer", + "hipJitOptionErrorLogBufferSizeBytes", + "hipJitOptionFallbackStrategy", + "hipJitOptionFastCompile", + "hipJitOptionGenerateDebugInfo", + "hipJitOptionGenerateLineInfo", + "hipJitOptionInfoLogBuffer", + "hipJitOptionInfoLogBufferSizeBytes", + "hipJitOptionLogVerbose", + "hipJitOptionMaxRegisters", + "hipJitOptionNumOptions", + "hipJitOptionOptimizationLevel", + "hipJitOptionSm3xOpt", + "hipJitOptionTarget", + "hipJitOptionTargetFromContext", + "hipJitOptionThreadsPerBlock", + "hipJitOptionWallTime", + "hipMemoryType", + "hipMemoryTypeArray", + "hipMemoryTypeDevice", + "hipMemoryTypeHost", + "hipMemoryTypeManaged", + "hipMemoryTypeUnified", + "hipMemoryTypeUnregistered", + "hipModuleGetFunction", + "hipModuleLaunchKernel", + "hipModuleLoadDataEx", + "hipModule_t", + "hipPointerAttribute_t", + "hipPointerGetAttributes", + "hipStream_t", + "hipSuccess", + "hipUUID", +] diff --git a/projects/eudsl-python-extras/examples/cuda_e2e.ipynb b/projects/eudsl-python-extras/examples/cuda_e2e.ipynb new file mode 100644 index 00000000..5f03bf51 --- /dev/null +++ b/projects/eudsl-python-extras/examples/cuda_e2e.ipynb @@ -0,0 +1,461 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, 
+ "kernelspec": { + "name": "python3", + "language": "python", + "display_name": "Python 3 (ipykernel)" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Based on [transform-mma-sync-matmul-f16-f16-accum.mlir](https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6)" + ], + "metadata": { + "collapsed": false + } + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Download mlir-python-bindings with CUDA support" + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "BRANCH = os.getenv(\"BRANCH\", \"main\")\n", + "os.environ[\"BRANCH\"] = BRANCH\n", + "os.environ[\"SCRIPT_ADDRESS\"] = f\"https://raw.githubusercontent.com/makslevental/mlir-python-extras/refs/heads/{BRANCH}/scripts/get_latest_bindings.py\"" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xh-QUDWiX-FD", + "outputId": "6865a63a-daa4-4610-e33a-721d37c0211f", + "ExecuteTime": { + "end_time": "2025-05-20T20:34:42.466337Z", + "start_time": "2025-05-20T20:34:42.464352Z" + } + }, + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-20T20:34:49.667778Z", + "start_time": "2025-05-20T20:34:43.565625Z" + } + }, + "cell_type": "code", + "source": [ + "%%bash\n", + "curl $SCRIPT_ADDRESS -o get_latest_bindings.py\n", + "latest_cuda_version=$(python get_latest_bindings.py \"cuda\")\n", + "pip install -q mlir_python_bindings==$latest_cuda_version -f https://makslevental.github.io/wheels\n", + "pip install -q git+https://github.com/makslevental/mlir-python-extras@$BRANCH" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " % Total % Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "100 2421 100 2421 0 0 36455 0 --:--:-- 
--:--:-- --:--:-- 36681\n" + ] + } + ], + "execution_count": 4 + }, + { + "cell_type": "markdown", + "source": [ + "# Boilerplate" + ], + "metadata": { + "id": "OSATAYhg7pSZ" + } + }, + { + "cell_type": "code", + "source": [ + "from pathlib import Path\n", + "\n", + "import mlir.extras.types as T\n", + "from mlir.dialects import builtin\n", + "from mlir.dialects.transform import any_op_t\n", + "from mlir.dialects.transform.extras import named_sequence\n", + "from mlir.dialects.transform.structured import MatchInterfaceEnum\n", + "from mlir.ir import StringAttr, UnitAttr\n", + "\n", + "from mlir import _mlir_libs\n", + "from mlir.extras.ast.canonicalize import canonicalize\n", + "from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule\n", + "from mlir.extras.dialects import arith, memref, scf, gpu\n", + "from mlir.extras.dialects import linalg\n", + "from mlir.extras.dialects import transform\n", + "from mlir.extras.dialects.func import func\n", + "from mlir.extras.runtime.passes import Pipeline, run_pipeline\n", + "from mlir.extras.runtime.refbackend import LLVMJITBackend\n", + "from mlir.extras.util import find_ops\n", + "\n", + "CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / f\"libmlir_cuda_runtime.so\"\n", + "assert CUDA_RUNTIME_LIB_PATH.exists()" + ], + "metadata": { + "id": "_R-_0M5ZYO8p", + "ExecuteTime": { + "end_time": "2025-05-20T20:34:54.377581Z", + "start_time": "2025-05-20T20:34:54.374777Z" + } + }, + "outputs": [], + "execution_count": 6 + }, + { + "cell_type": "markdown", + "source": [ + "# Context" + ], + "metadata": { + "id": "s-JTcrjo7tNK" + } + }, + { + "cell_type": "code", + "source": [ + "ctx = RAIIMLIRContext()\n", + "module = ExplicitlyManagedModule()" + ], + "metadata": { + "id": "AGpWj9BzZLC_" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Kernel and helper code" + ], + "metadata": { + "id": "qGcDtgkv71YB" + } + }, + { + "cell_type": "code", + "source": [ + 
"range_ = scf.range_\n", + "\n", + "M, K, N = 16, 16, 8\n", + "\n", + "# forward reference...\n", + "# TODO(max): figure out closures...\n", + "printMemrefF32_ = []\n", + "\n", + "\n", + "@func\n", + "def compute_linspace_val(ridx: T.index(), cidx: T.index(), stride_cidx: T.index()):\n", + " r = arith.index_cast(ridx, to=T.i32())\n", + " c = arith.index_cast(cidx, to=T.i32())\n", + " stride_c = arith.index_cast(stride_cidx, to=T.i32())\n", + " v2 = r * stride_c\n", + " v3 = c + v2\n", + " v4 = arith.sitofp(T.f16(), v3)\n", + " factor = arith.constant(64.0, T.f16())\n", + " v5 = arith.divf(v4, factor)\n", + " return v5\n", + "\n", + "\n", + "@func\n", + "@canonicalize(using=scf.canonicalizer)\n", + "def print_lhs_as_memref_32(lhs: T.memref(M, K, T.f16())):\n", + " M = memref.dim(lhs, 0)\n", + " K = memref.dim(lhs, 1)\n", + " tmp_alloc = memref.alloc((M, K), T.f32())\n", + " for m in range_(0, M):\n", + " for k in range_(0, K):\n", + " f16 = lhs[m, k]\n", + " f32 = arith.extf(T.f32(), f16)\n", + " tmp_alloc[m, k] = f32\n", + "\n", + " casted = memref.cast(T.memref(T.f32()), tmp_alloc)\n", + " printMemrefF32_[0](casted)\n", + " memref.dealloc(tmp_alloc)\n", + "\n", + "\n", + "@func\n", + "@canonicalize(using=scf.canonicalizer)\n", + "def print_rhs_as_memref_32(rhs: T.memref(K, N, T.f16())):\n", + " K = memref.dim(rhs, 0)\n", + " N = memref.dim(rhs, 1)\n", + " tmp_alloc = memref.alloc((K, N), T.f32())\n", + " for k in range_(0, K):\n", + " for n in range_(0, N):\n", + " f16 = rhs[k, n]\n", + " f32 = arith.extf(T.f32(), f16)\n", + " tmp_alloc[k, n] = f32\n", + "\n", + " casted = memref.cast(T.memref(T.f32()), tmp_alloc)\n", + " printMemrefF32_[0](casted)\n", + " memref.dealloc(tmp_alloc)\n", + "\n", + "\n", + "@func\n", + "@canonicalize(using=scf.canonicalizer)\n", + "def print_res_as_memref_32(res: T.memref(M, N, T.f16())):\n", + " c0 = arith.constant(0, index=True)\n", + " c1 = arith.constant(1, index=True)\n", + " M = memref.dim(res, c0)\n", + " N = memref.dim(res, 
c1)\n", + " tmp_alloc = memref.alloc((M, N), T.f32())\n", + " for m in range_(0, M):\n", + " for n in range_(0, N):\n", + " f16 = res[m, n]\n", + " f32 = arith.extf(T.f32(), f16)\n", + " tmp_alloc[m, n] = f32\n", + "\n", + " casted = memref.cast(T.memref(T.f32()), tmp_alloc)\n", + " printMemrefF32_[0](casted)\n", + " memref.dealloc(tmp_alloc)\n", + "\n", + "\n", + "@func\n", + "@canonicalize(using=scf.canonicalizer)\n", + "def main():\n", + " lhs = memref.alloc((M, K), T.f16())\n", + " rhs = memref.alloc((K, N), T.f16())\n", + " res = memref.alloc((M, N), T.f16())\n", + "\n", + " M_ = memref.dim(res, 0)\n", + " N_ = memref.dim(res, 1)\n", + " K_ = memref.dim(lhs, 1)\n", + "\n", + " _f1 = arith.constant(1.0e00, T.f16())\n", + " _f0 = arith.constant(0.0e00, T.f16())\n", + " _c32 = arith.constant(32, T.index())\n", + "\n", + " # Initialize the lhs matrix with a linspace function.\n", + " for r in range_(0, M_):\n", + " for c in range_(0, K_):\n", + " idx = compute_linspace_val(r, c, K_)\n", + " lhs[r, c] = idx\n", + "\n", + " # Initialize the rhs matrix with a linspace function.\n", + " for r in range_(0, K_):\n", + " for c in range_(0, N_):\n", + " idx = compute_linspace_val(r, c, N_)\n", + " rhs[r, c] = idx\n", + "\n", + " # Initialize the res matrix with a linspace function.\n", + " for r in range_(0, M_):\n", + " for c in range_(0, N_):\n", + " idx = compute_linspace_val(r, c, N_)\n", + " res[r, c] = idx\n", + "\n", + " ulhs = memref.cast(T.memref(T.f16()), lhs)\n", + " urhs = memref.cast(T.memref(T.f16()), rhs)\n", + " ures = memref.cast(T.memref(T.f16()), res)\n", + " gpu.host_register(ulhs)\n", + " gpu.host_register(urhs)\n", + " gpu.host_register(ures)\n", + "\n", + " print_lhs_as_memref_32(lhs)\n", + " print_rhs_as_memref_32(rhs)\n", + "\n", + " @gpu.launch(grid_size=[1, 1, 1], block_size=[32, 1, 1])\n", + " def kernel(bx, by, bz, tx, ty, tz, *grid_block_sizes):\n", + " linalg.matmul(lhs, rhs, res)\n", + "\n", + " print_res_as_memref_32(res)\n", + "\n", + 
"\n", + "@builtin.module(attrs={\"transform.target_tag\": StringAttr.get(\"payload\")})\n", + "def payload():\n", + " compute_linspace_val.emit()\n", + "\n", + " @func\n", + " def printMemrefF32(x: T.memref(T.f32())):\n", + " ...\n", + "\n", + " printMemrefF32_.append(printMemrefF32)\n", + "\n", + " print_lhs_as_memref_32.emit()\n", + " print_rhs_as_memref_32.emit()\n", + " print_res_as_memref_32.emit()\n", + " main.emit()" + ], + "metadata": { + "id": "7oQk4xJd72FI" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Transform schedule\n" + ], + "metadata": { + "id": "a0vJZrpR74KB" + } + }, + { + "cell_type": "code", + "source": [ + "@builtin.module(attrs={\"transform.with_named_sequence\": UnitAttr.get()})\n", + "def mod_transform():\n", + " @named_sequence(\n", + " \"main\", [any_op_t()], [], arg_attrs=[{\"transform.readonly\": UnitAttr.get()}]\n", + " )\n", + " def main(module: any_op_t()):\n", + " matmul = transform.match(module, [\"linalg.matmul\"])\n", + " transform.nvgpu.rewrite_matmul_as_mma_sync(matmul)\n", + " # clean up to simplify test below...\n", + " all_loops = transform.match(\n", + " module, interface=MatchInterfaceEnum.LoopLikeInterface\n", + " )\n", + " transform.apply_licm(all_loops)\n", + " transform.apply_cse(module)" + ], + "metadata": { + "id": "EaBgGTIz72ci" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# \"Finish\" the module" + ], + "metadata": { + "id": "ADbabroS8ND2" + } + }, + { + "cell_type": "code", + "source": [ + "module = module.finish()\n", + "print(module)\n", + "assert module.operation.verify()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CUOsYXaW8QKC", + "outputId": "f8592229-1d9b-4c52-9133-30fd52c2716d" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Execute the transform schedule" + ], + "metadata": { + "id": 
"0xN5kNvZ8Tyf" + } + }, + { + "cell_type": "code", + "source": [ + "mod = run_pipeline(\n", + " module,\n", + " Pipeline().transform_interpreter(\n", + " entry_point=\"main\", debug_payload_root_tag=\"payload\"\n", + " ),\n", + ")\n", + "print(mod)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lLwQLPD98Q4d", + "outputId": "ecfa6c9a-15eb-40c7-df29-f43fcac02fbf" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Lower to NVVM (and LLVM)" + ], + "metadata": { + "id": "D_NURglF8ZZW" + } + }, + { + "cell_type": "code", + "source": [ + "CUDA_RUNTIME_EXISTS = Path(\"/usr/local/cuda\").exists()\n", + "if CUDA_RUNTIME_EXISTS:\n", + " backend = LLVMJITBackend([CUDA_RUNTIME_LIB_PATH])\n", + " # this doesn't actually anything (no pipeline) but does generate C API/wrappers\n", + " compiled_module = backend.compile(\n", + " find_ops(\n", + " mod.operation,\n", + " lambda x: \"transform.target_tag\" in x.attributes\n", + " and x.attributes[\"transform.target_tag\"].value == \"payload\",\n", + " single=True,\n", + " ),\n", + " Pipeline().add_pass(\n", + " \"gpu-lower-to-nvvm-pipeline\",\n", + " **{\n", + " \"cubin-chip\": \"sm_80\",\n", + " \"cubin-features\": \"+ptx76\",\n", + " \"cubin-format\": \"fatbin\",\n", + " },\n", + " ),\n", + " )\n", + " print(compiled_module)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9IoWjgc48bcn", + "outputId": "39550464-fd37-4e6d-a257-e803b746d8de" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Load and run" + ], + "metadata": { + "id": "sOapyydH8n4h" + } + }, + { + "cell_type": "code", + "source": [ + "if CUDA_RUNTIME_EXISTS:\n", + " backend.load(compiled_module).main_capi_wrapper()" + ], + "metadata": { + "id": "pOEC4Qgw8p9X" + }, + "outputs": [], + "execution_count": null + } + ] +} diff --git a/projects/eudsl-python-extras/examples/cuda_matmul_opt.py 
b/projects/eudsl-python-extras/examples/cuda_matmul_opt.py new file mode 100644 index 00000000..7f48c833 --- /dev/null +++ b/projects/eudsl-python-extras/examples/cuda_matmul_opt.py @@ -0,0 +1,1234 @@ +import contextlib +import math + +import mlir.extras.types as T +import numpy as np +from mlir.dialects import builtin + +from util import cuda_bindings_not_installed +from mlir.extras.ast.canonicalize import canonicalize +from mlir.extras.context import ( + mlir_mod_ctx, + MLIRContext, +) +from mlir.extras.dialects import arith, memref, gpu, scf, linalg, vector, nvgpu +from mlir.extras.dialects.gpu import ( + block_idx, + thread_idx, + block_dim, + get_compile_object_bytes, + smem_space, +) +from mlir.extras.dialects.llvm import llvm_ptr_t +from mlir.extras.dialects.memref import S +from mlir.extras.dialects.scf import range_ +from mlir.extras.runtime.passes import Pipeline, run_pipeline + +# noinspection PyUnresolvedReferences +from mlir.extras.util import find_ops, enable_debug as enable_debug + +# just so it doesn't get DCE'd by black/reformat +_ = memref + + +def build_cuda_func(compiled_module, kernel_name="naive"): + from cupy.cuda import Module + + ptx = get_compile_object_bytes(compiled_module) + mod = Module() + mod.load(ptx) + return mod.get_function(kernel_name) + + +def print_ptx(compiled_module): + ptx = get_compile_object_bytes(compiled_module) + print(ptx.decode()) + + +def compile_module( + module, + chip="sm_80", + features="+ptx83", + opt_level=2, + enable_ir_printing=False, + print_ptx_=False, + full_pipeline=True, +): + if enable_ir_printing: + print_ptx_ = True + if full_pipeline: + p = ( + Pipeline() + .convert_linalg_to_loops() + .convert_nvgpu_to_nvvm() + .gpu_kernel_outlining() + .convert_vector_to_scf() + .convert_scf_to_cf() + .convert_nvvm_to_llvm() + .convert_func_to_llvm() + .expand_strided_metadata() + .add_pass( + "nvvm-attach-target", + **{ + "chip": chip, + "features": features, + "O": str(opt_level), + }, + ) + .lower_affine() + 
.convert_arith_to_llvm() + .convert_index_to_llvm() + .canonicalize() + .cse() + .Gpu( + Pipeline() + .strip_debuginfo() + # TODO(max): upstream this (add to gpu pipeline) + # vector.transfer + .convert_vector_to_llvm() + .convert_gpu_to_nvvm(use_bare_ptr_memref_call_conv=True) + .canonicalize() + .cse() + .reconcile_unrealized_casts() + ) + .gpu_to_llvm(use_bare_pointers_for_kernels=True) + .gpu_module_to_binary(format="isa") + .canonicalize() + .cse() + .reconcile_unrealized_casts() + ) + else: + p = Pipeline().add_pass( + "gpu-lower-to-nvvm-pipeline", + # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18 + **{ + "cubin-chip": chip, + "cubin-features": features, + "cubin-format": "isa", + "kernel-bare-ptr-calling-convention": "1", + "opt-level": str(opt_level), + # "cubin-format": "fatbin", + # "cubin-format": "bin", + }, + ) + mod = run_pipeline(module, p, enable_ir_printing=enable_ir_printing) + + if print_ptx_: + print_ptx(mod) + + return mod + + +@contextlib.contextmanager +def time_cuda(): + import cupy as cp + + start_gpu = cp.cuda.Event() + end_gpu = cp.cuda.Event() + + start_gpu.record() + yield start_gpu, end_gpu + end_gpu.record() + end_gpu.synchronize() + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_naive[ + M, + K, + N, + dtype, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + one = arith.constant(1.0, type=dtype) + tmp = arith.constant(0, type=dtype) + + # this is from the example and it's basically a mistake + # it increments the row for each adjacent thread id + # uncomment the print to see + r = block_dim.x * block_idx.x + thread_idx.x + c = block_dim.y * block_idx.y + thread_idx.y + # tid = gpu.thread_id() + # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c) + + for k, tmp, _ in range_(K, iter_args=[tmp]): + tmp += A[r, k] * B[k, c] + tmp = yield 
tmp + C[r, c] = tmp + one + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_naive_row_order[ + M, + K, + N, + dtype, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + one = arith.constant(1.0, type=dtype) + tmp = arith.constant(0, type=dtype) + + # increment along the cols (ie preserve row-order access) + c = block_dim.x * block_idx.x + thread_idx.x + r = block_dim.y * block_idx.y + thread_idx.y + # tid = gpu.thread_id() + # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c) + + for k, tmp, _ in range_(K, iter_args=[tmp]): + tmp += A[r, k] * B[k, c] + tmp = yield tmp + C[r, c] = tmp + one + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_coalesce[ + M, + K, + N, + dtype, + BLOCK_SIZE = 32, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + + tid = gpu.thread_id() + # this is actually floordiv + r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE) + c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE) + # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c) + + one = arith.constant(1.0, type=dtype) + tmp = arith.constant(0, type=dtype) + + for k, tmp, _ in range_(K, iter_args=[tmp]): + # k varies per core while c varies with tid + # apparently that's fine? i guess all the loads can happen + # because there's enough scratch per SM to prefetch all the data each thread needs? 
+ tmp += A[r, k] * B[k, c] + tmp = yield tmp + C[r, c] = tmp + one + + +# So if you try to load something like: +# +# B.T: +# +# 0 0 0 0 0 0 0 0 +# 1 1 1 1 1 1 1 1 +# 2 2 2 2 2 2 2 2 +# +# vs +# +# B: +# 0 1 2 3 4 5 6 7 8 +# 0 1 2 3 4 5 6 7 8 +# 0 1 2 3 4 5 6 7 8 +# +# In B, you are feeding all threads with a single load (say warp can load 8 elements at a time) and then you increment k +# +# in B.T, a single load is feeding only a single thread, so others are probably waiting for their load to happen +# these are the issues by threads: +# +# 0: (0, 0), (1, 0), (2, 0) +# 1: (0, 1), (1, 1), (2, 1) +# 2: (0, 2), (1, 2), (2, 2) +# +# warp receives these issues: +# +# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2) +# +# warp issues coalesced reads: +# +# (0, 0:2), (1, 0:2), (2,0:2) +# so even though the threads have bad memory access pattern +# the warp has good memory access pattern +# and since the actual load happens at warp level +# it's good +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_coalesce_transpose_B[ + M, + K, + N, + dtype, + BLOCK_SIZE = 32, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + + tid = gpu.thread_id() + r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE) + c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE) + + one = arith.constant(1.0, type=dtype) + tmp = arith.constant(0, type=dtype) + + for k, tmp, _ in range_(K, iter_args=[tmp]): + # this is slower because c is incremented with each tid + # so you break memory coalescing + # but k now being on the row order dim doesn't help? 
+ tmp += A[r, k] * B[c, k] + tmp = yield tmp + C[r, c] = tmp + one + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_block[ + M, + K, + N, + dtype, + BLOCK_SIZE = 32, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + # allocate buffer for current block in fast shared mem + # shared mem is shared between all threads in a block + base = gpu.dynamic_shared_memory() + A_shared = memref.view(base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype) + B_shared = memref.view( + base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype, shift=BLOCK_SIZE * BLOCK_SIZE + ) + + # the inner row & col that we're accessing in this thread + tid = gpu.thread_id() + thread_row = tid / BLOCK_SIZE + thread_col = tid % BLOCK_SIZE + + # the output block that we want to compute in this threadblock + c_row = block_idx.x * BLOCK_SIZE + c_col = block_idx.y * BLOCK_SIZE + + tmp = arith.constant(0, type=dtype) + + for bk_idx, tmp, _ in range_(0, K, BLOCK_SIZE, iter_args=[tmp]): + A_ = A[c_row : c_row + BLOCK_SIZE, bk_idx : bk_idx + BLOCK_SIZE] + B_ = B[bk_idx : bk_idx + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE] + + # Have each thread load one of the elements in A & B + # Make the threadCol (=thread_idx.x) the consecutive index + # to allow global memory access coalescing + A_shared[thread_row, thread_col] = A_[thread_row, thread_col] + B_shared[thread_row, thread_col] = B_[thread_row, thread_col] + + # block threads in this block until cache is fully populated + gpu.barrier() + + # execute the dotproduct on the currently cached block + for dot_idx, tmp, _ in range_(BLOCK_SIZE, iter_args=[tmp]): + tmp += A_shared[thread_row, dot_idx] * B_shared[dot_idx, thread_col] + tmp = yield tmp + + # need to sync again at the end, to avoid faster threads + # fetching the next block into the cache before slower threads are done + gpu.barrier() + + tmp = yield tmp + + one = arith.constant(1.0, type=dtype) + C_ = 
C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE] + C_[thread_row, thread_col] = tmp + one + + +class CUDABindingsNotInstalled(Exception): + pass + + +def prepare_non_tiled_kernel(ctx: MLIRContext, kernel, M, K, N, BLOCK_SIZE=32): + dtype = T.f32() + npy_dtype = np.float32 + + gpu.set_container_module(ctx.module) + + @gpu.module("matmul", ["#nvvm.target"]) + def matmul_mod(): + kernel[M, K, N, dtype].emit() + + assert ctx.module.operation.verify() + + if cuda_bindings_not_installed(): + raise CUDABindingsNotInstalled() + + kernel_name = kernel.__name__ + compiled_module = compile_module(ctx.module) + cuda_func = build_cuda_func(compiled_module, kernel_name) + # print_ptx(compiled_module) + + grid_dims = (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE)) + block_dims = (BLOCK_SIZE, BLOCK_SIZE) + + if "shared" in kernel_name: + shared_mem = 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype().nbytes + else: + shared_mem = 0 + + return ( + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + "transpose_B" in kernel_name, + ) + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_1d_block_tiling[ + M, + K, + N, + dtype, + BM, + BN, + BK, + TM, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + base = gpu.dynamic_shared_memory() + A_shared = memref.view(base, (BM, BK), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + tid = gpu.thread_id() + thread_col = tid % BN + thread_row = tid / BN + + inner_col_A = tid % BK # warp-level GMEM coalescing + inner_row_A = tid / BK + inner_col_B = tid % BN # warp-level GMEM coalescing + inner_row_B = tid / BN + + thread_results = memref.alloca((TM,), dtype) + linalg.fill(0, thread_results) + + for bk_idx in range_(0, K, BK): + # Move blocktile to beginning of A's row and B's column + A_ = A[c_row : c_row + BM, 
bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + A_shared[inner_row_A, inner_col_A] = A_[inner_row_A, inner_col_A] + B_shared[inner_row_B, inner_col_B] = B_[inner_row_B, inner_col_B] + + gpu.barrier() + + for dot_idx in range_(BK): + tmp_B = B_shared[dot_idx, thread_col] + for res_idx, tmp_B, _ in range_(TM, iter_args=[tmp_B]): + thread_results[res_idx] += ( + A_shared[thread_row * TM + res_idx, dot_idx] * tmp_B + ) + yield tmp_B + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + C_ = C[c_row : c_row + BM, c_col : c_col + BN] + for res_idx in range_(TM): + C_[thread_row * TM + res_idx, thread_col] = thread_results[res_idx] + one + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_2d_block_tiling[ + M, + K, + N, + dtype, + BM, + BN, + BK, + TM, + TN, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + base = gpu.dynamic_shared_memory() + A_shared = memref.view(base, (BM, BK), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + total_results_blocktile = BM * BN + num_threads_blocktile = total_results_blocktile // (TM * TN) + + tid = gpu.thread_id() + # BN/TN are the number of threads to span a column + thread_col = tid % (BN // TN) + thread_row = tid / (BN // TN) + + inner_col_A = tid % BK # warp-level GMEM coalescing + inner_row_A = tid / BK + stride_A = num_threads_blocktile // BK + + inner_col_B = tid % BN # warp-level GMEM coalescing + inner_row_B = tid / BN + stride_B = num_threads_blocktile // BN + + thread_results = memref.alloca((TM, TN), dtype) + linalg.fill(0, thread_results) + + reg_M = memref.alloca((TM,), dtype) + linalg.fill(0, reg_M) + + reg_N = memref.alloca((TN,), dtype) + linalg.fill(0, reg_N) + + for bk_idx in range_(0, K, BK): + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + 
BK, c_col : c_col + BN] + + for load_offset in range_(0, BM, stride_A): + A_shared[inner_row_A + load_offset, inner_col_A] = A_[ + inner_row_A + load_offset, inner_col_A + ] + for load_offset in range_(0, BK, stride_B): + B_shared[inner_row_B + load_offset, inner_col_B] = B_[ + inner_row_B + load_offset, inner_col_B + ] + + gpu.barrier() + + for dot_idx in range_(BK): + for i in range_(TM): + reg_M[i] = A_shared[thread_row * TM + i, dot_idx] + for i in range_(TN): + reg_N[i] = B_shared[dot_idx, thread_col * TN + i] + + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + thread_results[res_idx_m, res_idx_n] += ( + reg_M[res_idx_m] * reg_N[res_idx_n] + ) + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + C_ = C[c_row : c_row + BM, c_col : c_col + BN] + + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + C_[thread_row * TM + res_idx_m, thread_col * TN + res_idx_n] = ( + thread_results[res_idx_m, res_idx_n] + one + ) + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_2d_block_tiling_vectorize[ + M, + K, + N, + dtype, + BM, + BN, + BK, + TM, + TN, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + VECTOR_WIDTH = 4 + DTYPE_WIDTH = dtype.width // 8 + + # ld.global.v4.u32 and st.global.v4.f32 emitted only if input args are aligned + # alignment for cupy is 512 bytes https://github.com/cupy/cupy/blob/59e6c2b2e0c722b09c7a7af13f908942ef7806cc/cupy/cuda/memory.pyx#L805-L809 + # so we're good + memref.assume_alignment(A, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(B, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(C, VECTOR_WIDTH * DTYPE_WIDTH) + + base = gpu.dynamic_shared_memory() + base = memref.memory_space_cast(T.memref(S, element_type=T.i8()), base) + + # transpose A + A_shared = memref.view(base, (BK, BM), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = 
block_idx.y * BM + c_col = block_idx.x * BN + + tid = gpu.thread_id() + # BN/TN are the number of threads to span a column + thread_col = tid % (BN // TN) + thread_row = tid / (BN // TN) + + # calculating the indices that this thread will load into SMEM + # we'll load 128bit / 32bit = 4 elements per thread at each step + inner_col_A = tid % (BK // VECTOR_WIDTH) # warp-level GMEM coalescing + inner_row_A = tid / (BK // VECTOR_WIDTH) + inner_col_B = tid % (BN // VECTOR_WIDTH) # warp-level GMEM coalescing + inner_row_B = tid / (BN // VECTOR_WIDTH) + + thread_results = memref.alloca((TM, TN), dtype) + linalg.fill(0, thread_results) + + reg_M = memref.alloca((TM,), dtype) + linalg.fill(0, reg_M) + + reg_N = memref.alloca((TN,), dtype) + linalg.fill(0, reg_N) + + for bk_idx in range_(0, K, BK): + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + A_vec = vector.load_( + A_, [inner_row_A, inner_col_A * VECTOR_WIDTH], T.vector(VECTOR_WIDTH, dtype) + ) + for j in range(VECTOR_WIDTH): + # transpose A while loading it + A_shared[inner_col_A * VECTOR_WIDTH + j, inner_row_A] = A_vec[j] + + B_vec = vector.load_( + B_, [inner_row_B, inner_col_B * VECTOR_WIDTH], T.vector(VECTOR_WIDTH, dtype) + ) + vector.store(B_vec, B_shared, [inner_row_B, inner_col_B * VECTOR_WIDTH]) + + gpu.barrier() + + for dot_idx in range_(BK): + for i in range_(TM): + reg_M[i] = A_shared[dot_idx, thread_row * TM + i] + + for i in range_(TN): + reg_N[i] = B_shared[dot_idx, thread_col * TN + i] + + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + thread_results[res_idx_m, res_idx_n] += ( + reg_M[res_idx_m] * reg_N[res_idx_n] + ) + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + C_ = C[c_row : c_row + BM, c_col : c_col + BN] + + for res_idx_m in range_(TM): + for res_idx_n in range_(0, TN, VECTOR_WIDTH): + tmp = vector.load_( + C_, + [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n], + T.vector(VECTOR_WIDTH, dtype), + ) + 
for j in range(VECTOR_WIDTH): + tmp[j] = thread_results[res_idx_m, res_idx_n + j] + one + vector.store( + tmp, C_, [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n] + ) + + +WARP_SIZE = 32 + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_warp_tiling[ + M, + K, + N, + dtype, + BM, + BN, + BK, + WM, + WN, + WNITER, + TM, + TN, + NUM_THREADS, + A_t = T.memref(M, K, dtype), + B_t = T.memref(K, N, dtype), + C_t = T.memref(M, N, dtype), +](A: A_t, B: B_t, C: C_t): + VECTOR_WIDTH = 4 + DTYPE_WIDTH = dtype.width // 8 + + tid = gpu.thread_id() + + # ld.global.v4.u32 and st.global.v4.f32 emitted only if input args are aligned + # alignment for cupy is 512 bytes https://github.com/cupy/cupy/blob/59e6c2b2e0c722b09c7a7af13f908942ef7806cc/cupy/cuda/memory.pyx#L805-L809 + # so we're good + memref.assume_alignment(A, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(B, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(C, VECTOR_WIDTH * DTYPE_WIDTH) + + base = gpu.dynamic_shared_memory() + base = memref.memory_space_cast(T.memref(S, element_type=T.i8()), base) + + # transpose A + A_shared = memref.view(base, (BK, BM), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + # Placement of the warp in the threadblock tile + warp_idx = tid / WARP_SIZE + warp_row = warp_idx / (BN // WN) + warp_col = warp_idx % (BN // WN) + + # size of the warp subtile + WMITER = (WM * WN) // (WARP_SIZE * TM * TN * WNITER) + WSUBM = WM // WMITER + WSUBN = WN // WNITER + + # Placement of the thread in the warp subtile + thread_idx_in_warp = tid % WARP_SIZE + thread_col_in_warp = thread_idx_in_warp % (WSUBN // TN) + thread_row_in_warp = thread_idx_in_warp / (WSUBN // TN) + + # calculating the indices that this thread will load into SMEM + # we'll load 128bit / 32bit = 4 elements per thread at each step + inner_row_A = tid / (BK // VECTOR_WIDTH) + inner_col_A = tid % 
(BK // VECTOR_WIDTH) + row_stride_A = (NUM_THREADS * VECTOR_WIDTH) // BK + inner_row_B = tid / (BN // VECTOR_WIDTH) + inner_col_B = tid % (BN // VECTOR_WIDTH) + row_stride_B = NUM_THREADS // (BN // VECTOR_WIDTH) + + # allocate thread-local cache for results in registerfile + thread_results = memref.alloca((WMITER * TM, WNITER * TN), dtype) + linalg.fill(0, thread_results) + + reg_M = memref.alloca((WMITER, TM), dtype) + linalg.fill(0, reg_M) + + reg_N = memref.alloca((WNITER, TN), dtype) + linalg.fill(0, reg_N) + + for bk_idx in range_(0, K, BK): + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + for offset in range(0, BM - row_stride_A + 1, row_stride_A): + A_vec = vector.load_( + A_, + [inner_row_A + offset, inner_col_A * VECTOR_WIDTH], + T.vector(VECTOR_WIDTH, dtype), + ) + for j in range(VECTOR_WIDTH): + # transpose A while loading it + A_shared[inner_col_A * VECTOR_WIDTH + j, inner_row_A + offset] = A_vec[ + j + ] + + for offset in range(0, BK - row_stride_B + 1, row_stride_B): + B_vec = vector.load_( + B_, + [inner_row_B + offset, inner_col_B * VECTOR_WIDTH], + T.vector(VECTOR_WIDTH, dtype), + ) + vector.store( + B_vec, B_shared, [inner_row_B + offset, inner_col_B * VECTOR_WIDTH] + ) + + gpu.barrier() + + for dot_idx in range_(BK): + for w_sub_row_idx in range_(WMITER): + for i in range_(TM): + reg_M[w_sub_row_idx, i] = A_shared[ + dot_idx, + warp_row * WM + + w_sub_row_idx * WSUBM + + thread_row_in_warp * TM + + i, + ] + + for w_sub_col_idx in range_(WNITER): + for i in range_(TN): + reg_N[w_sub_col_idx, i] = B_shared[ + dot_idx, + warp_col * WN + + w_sub_col_idx * WSUBN + + thread_col_in_warp * TN + + i, + ] + + for w_sub_row_idx in range_(WMITER): + for w_sub_col_idx in range_(WNITER): + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + thread_results[ + w_sub_row_idx * TM + res_idx_m, + w_sub_col_idx * TN + res_idx_n, + ] += ( + reg_M[w_sub_row_idx, res_idx_m] + * reg_N[w_sub_col_idx, 
res_idx_n] + ) + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + + for w_sub_row_idx in range_(WMITER): + for w_sub_col_idx in range_(WNITER): + r = c_row + warp_row * WM + w_sub_row_idx * WSUBM + c = c_col + warp_col * WN + w_sub_col_idx * WSUBN + C_ = C[r : r + WSUBM, c : c + WSUBN] + for res_idx_m in range_(TM): + for res_idx_n in range_(0, TN, VECTOR_WIDTH): + tmp = vector.load_( + C_, + [ + thread_row_in_warp * TM + res_idx_m, + thread_col_in_warp * TN + res_idx_n, + ], + T.vector(VECTOR_WIDTH, dtype), + ) + for j in range(VECTOR_WIDTH): + tmp[j] = ( + thread_results[ + w_sub_row_idx * TM + res_idx_m, + w_sub_col_idx * TN + res_idx_n + j, + ] + + one + ) + vector.store( + tmp, + C_, + [ + thread_row_in_warp * TM + res_idx_m, + thread_col_in_warp * TN + res_idx_n, + ], + ) + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_tensor_core[ + M, + K, + N, + A_t = T.memref(M, K, T.f16()), + B_t = T.memref(K, N, T.f16()), + C_t = T.memref(M, N, T.f32()), + a_tma_t = llvm_ptr_t(), + b_tma_t = llvm_ptr_t(), +](A: A_t, B: B_t, C: C_t, a_tma: a_tma_t, b_tma: b_tma_t): + a_tma = builtin.unrealized_conversion_cast( + [ + nvgpu.TensorMapDescriptorType.get( + T.memref(128, 64, T.f16(), memory_space=smem_space()), + swizzle=int(nvgpu.TensorMapSwizzleKind.SWIZZLE_128B), + l2promo=int(nvgpu.TensorMapL2PromoKind.L2PROMO_NONE), + oob_fill=int(nvgpu.TensorMapOOBKind.OOB_ZERO), + interleave=int(nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE), + ) + ], + [a_tma], + ) + b_tma = builtin.unrealized_conversion_cast( + [ + nvgpu.TensorMapDescriptorType.get( + T.memref(64, 64, T.f16(), memory_space=smem_space()), + swizzle=int(nvgpu.TensorMapSwizzleKind.SWIZZLE_128B), + l2promo=int(nvgpu.TensorMapL2PromoKind.L2PROMO_NONE), + oob_fill=int(nvgpu.TensorMapOOBKind.OOB_ZERO), + interleave=int(nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE), + ) + ], + [b_tma], + ) + tid = gpu.thread_id() + is_thread_0 = tid == 0 + + mbarrier = 
nvgpu.mbarrier_create() + nvgpu.mbarrier_init(mbarrier, 1, 0, predicate=is_thread_0) + nvgpu.tma_prefetch_descriptor(a_tma) + nvgpu.tma_prefetch_descriptor(b_tma) + + base = gpu.dynamic_shared_memory() + + shift = 0 + A_shared = memref.view(base, (M, K), dtype=T.f16(), shift=shift) + shift += A_shared.n_elements + B_shared = memref.view(base, (K, N), dtype=T.f16(), shift=shift) + shift += B_shared.n_elements + + a = memref.view(base, (128, 64), dtype=T.f16(), shift=shift) + shift += a.n_elements + b1 = memref.view(base, (64, 64), dtype=T.f16(), shift=shift) + shift += b1.n_elements + b2 = memref.view(base, (64, 64), dtype=T.f16(), shift=shift) + + ta_count = a.n_elements + b1.n_elements + b2.n_elements + nvgpu.mbarrier_arrive_expect_tx(mbarrier, ta_count, 0, predicate=is_thread_0) + + nvgpu.tma_async_load( + a, + mbarrier, + a_tma, + coordinates=[0, 0], + mbar_id=0, + predicate=is_thread_0, + ) + nvgpu.tma_async_load( + b1, + mbarrier, + b_tma, + coordinates=[0, 0], + mbar_id=0, + predicate=is_thread_0, + ) + nvgpu.tma_async_load( + b2, + mbarrier, + b_tma, + coordinates=[64, 0], + mbar_id=0, + predicate=is_thread_0, + ) + nvgpu.mbarrier_try_wait_parity(mbarrier, mbar_id=0) + + accum = nvgpu.warpgroup_mma_init_accumulator( + nvgpu.warpgroup_accumulator_t(M, N, T.f32()) + ) + lhs = nvgpu.warpgroup_generate_descriptor( + nvgpu.warpgroup_descriptor(M, K, T.f16()), A_shared, a_tma + ) + rhs = nvgpu.warpgroup_generate_descriptor( + nvgpu.warpgroup_descriptor(K, N, T.f16()), B_shared, b_tma + ) + acc = nvgpu.warpgroup_mma(accum, lhs, rhs, transpose_b=True) + nvgpu.warpgroup_mma_store(acc, C) + + +def prepare_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): + dtype = T.f32() + npy_dtype = np.float32 + kernel_name = kernel.__name__ + + gpu.set_container_module(ctx.module) + + BK = 8 + TM = 8 + TN = 8 + if "2d" in kernel_name and M >= 128 and N >= 128: + BM = 128 + BN = 128 + else: + BM = 64 + BN = 64 + + @gpu.module("matmul", ["#nvvm.target"]) + def matmul_mod(): + 
kernel[M, K, N, dtype, BM, BN, BK, TM, TN].emit() + + assert ctx.module.operation.verify() + + if cuda_bindings_not_installed(): + raise CUDABindingsNotInstalled() + + compiled_module = compile_module(ctx.module) + cuda_func = build_cuda_func(compiled_module, kernel_name) + # print_ptx(compiled_module) + + grid_dims = (math.ceil(N / BN), math.ceil(M / BM)) + if "2d" in kernel_name: + block_dims = (BM // TM, BN // TN) + else: + block_dims = (BM // TM, BN) + + if "shared" in kernel_name: + shared_mem = ((BM * BK) + (BK * BN)) * npy_dtype().nbytes + else: + shared_mem = 0 + + return ( + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + False, + ) + + +def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): + dtype = T.f32() + npy_dtype = np.float32 + kernel_name = kernel.__name__ + + gpu.set_container_module(ctx.module) + + # Settings for A100 (looks like it works for 3070 too?) + NUM_THREADS = 128 + BN = 128 + BM = 64 + BK = 16 + WN = 64 + WM = 32 + WNITER = 1 + TN = 4 + TM = 4 + + @gpu.module("matmul", ["#nvvm.target"]) + def matmul_mod(): + kernel[M, K, N, dtype, BM, BN, BK, WM, WN, WNITER, TM, TN, NUM_THREADS].emit() + + # print(ctx.module) + assert ctx.module.operation.verify() + + if cuda_bindings_not_installed(): + raise CUDABindingsNotInstalled() + + compiled_module = compile_module(ctx.module) + cuda_func = build_cuda_func(compiled_module, kernel_name) + # print_ptx(compiled_module) + + grid_dims = (math.ceil(N / BN), math.ceil(M / BM)) + block_dims = (NUM_THREADS,) + shared_mem = ((BM * BK) + (BK * BN)) * npy_dtype().nbytes + + return ( + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + False, + ) + + +def prepare_tensor_core_kernel(ctx: MLIRContext, kernel, M, K, N): + dtype = T.f16() + npy_dtype = np.float16 + kernel_name = kernel.__name__ + + gpu.set_container_module(ctx.module) + + # Settings for A100 (looks like it works for 3070 too?) 
+ NUM_THREADS = 128 + BN = 128 + BM = 64 + BK = 16 + WN = 64 + WM = 32 + WNITER = 1 + TN = 4 + TM = 4 + + @gpu.module("matmul", ["#nvvm.target"]) + def matmul_mod(): + kernel[M, K, N, dtype].emit() + + assert ctx.module.operation.verify() + + if cuda_bindings_not_installed(): + raise CUDABindingsNotInstalled() + + compiled_module = compile_module( + ctx.module, chip="sm_90a", opt_level=3, full_pipeline=False + ) + # cuda_func = build_cuda_func(compiled_module, kernel_name) + # print_ptx(compiled_module) + + grid_dims = (math.ceil(N / BN), math.ceil(M / BM)) + block_dims = (NUM_THREADS,) + shared_mem = ((BM * BK) + (BK * BN)) * npy_dtype().nbytes + + return ( + # cuda_func, + None, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + False, + ) + + +def run_eval( + M, + K, + N, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + repeat_times=None, +): + import cupy as cp + + if repeat_times is None: + repeat_times = 50 + + A = np.random.randint(0, 10, (M, K)).astype(npy_dtype) + B = np.random.randint(0, 10, (K, N)).astype(npy_dtype) + C = np.zeros((M, N)).astype(npy_dtype) + + dA = cp.asarray(A) + if transpose_B: + dB = cp.asarray(np.ascontiguousarray(B.T)) + else: + dB = cp.asarray(B) + dC = cp.asarray(C) + + cuda_func( + grid_dims, + block_dims, + (dA.data.ptr, dB.data.ptr, dC.data.ptr), + shared_mem=shared_mem, + ) + C = cp.asnumpy(dC) + if not np.array_equal(C, A @ B + 1): + print(A @ B + 1) + print(C) + assert False + if repeat_times < 1: + return + + with time_cuda() as (start_gpu, end_gpu): + for _ in range(repeat_times): + cuda_func( + grid_dims, + block_dims, + (dA.data.ptr, dB.data.ptr, dC.data.ptr), + shared_mem=shared_mem, + ) + + t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) + + print(f"t={t_gpu / repeat_times:.6f} ms") + + +sizes = [128, 256, 512, 1024] +repeats = None + +for k in [ + sgemm_naive, + sgemm_naive_row_order, + sgemm_coalesce, + sgemm_coalesce_transpose_B, + sgemm_shared_mem_block, +]: + 
print(f"\n{k.__name__}") + for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + try: + cuda_func, grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_non_tiled_kernel(ctx, k, s, s, s) + ) + run_eval( + s, + s, + s, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + ) + except CUDABindingsNotInstalled: + continue + + +for k in [ + sgemm_shared_mem_1d_block_tiling, + sgemm_shared_mem_2d_block_tiling, + sgemm_shared_mem_2d_block_tiling_vectorize, +]: + print(f"\n{k.__name__}") + for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + try: + cuda_func, grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_tiled_kernel(ctx, k, s, s, s) + ) + run_eval( + s, + s, + s, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + ) + except CUDABindingsNotInstalled: + continue + +print(f"\n{sgemm_warp_tiling.__name__}") +for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + try: + cuda_func, grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_warp_tiled_kernel(ctx, sgemm_warp_tiling, s, s, s) + ) + run_eval( + s, + s, + s, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + ) + except CUDABindingsNotInstalled: + continue + + +sizes = [128, 256] + +for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + try: + cuda_func, grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_tensor_core_kernel(ctx, sgemm_tensor_core, s, s, s) + ) + # run_eval( + # s, + # s, + # s, + # cuda_func, + # grid_dims, + # block_dims, + # shared_mem, + # npy_dtype, + # transpose_B, + # ) + except CUDABindingsNotInstalled: + continue diff --git a/projects/eudsl-python-extras/examples/flash_attention.py 
import sys
from pathlib import Path

import mlir.extras.types as T
import numpy as np
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
from mlir.extras.dialects import memref, scf, arith, gpu, llvm
from mlir.dialects import math

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.gpu import (
    block_idx,
    thread_idx,
    grid_dim,
    func as gpu_func,
    set_container_module,
    module,
    get_compile_object_bytes,
)
from mlir.extras.runtime.passes import run_pipeline, Pipeline
from mlir.extras.util import find_ops

# noinspection PyUnresolvedReferences
from util import (
    hip_check,
    launch_kernel,
    hip_synchronize,
    hip_bindings_not_installed,
    get_hip_arch,
)


def init_copy_host_device(B, nh, N, d):
    """Build the host-side inputs/outputs and upload them to the device.

    Creates random integer-valued float32 ``q``, ``k``, ``v`` arrays of shape
    ``(B, nh, N, d)``, a zeroed ``l`` and a ``float32.min``-filled ``m`` of
    shape ``(B, nh, N)``, and a zeroed ``O`` shaped like ``q``.  Each array is
    copied into its own freshly ``hipMalloc``-ed buffer.

    Returns:
        ``(host, device)``: two parallel lists ordered ``q, k, v, l, m, O``.
    """
    from hip import hip

    q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
    k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
    v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
    l_h = np.zeros((B, nh, N), dtype=np.float32)
    # running row-max starts at the most negative finite float32
    m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32)
    O_h = np.zeros_like(q_h, dtype=np.float32)

    host = [q_h, k_h, v_h, l_h, m_h, O_h]
    # ndarray.nbytes == size * itemsize
    device = [hip_check(hip.hipMalloc(h.nbytes)) for h in host]

    for dev, h in zip(device, host):
        hip_check(
            hip.hipMemcpy(dev, h, h.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
        )

    return host, device


def copy_device_host(host, device):
    """Copy every device buffer back into its host array, then free the buffer.

    ``host`` and ``device`` are the parallel lists returned by
    ``init_copy_host_device``.  Returns ``host`` (updated in place).
    """
    from hip import hip

    # `dev` rather than `d`: avoid shadowing the module-level head dimension `d`
    for dev, h in zip(device, host):
        hip_check(
            hip.hipMemcpy(h, dev, h.nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)
        )
        hip_check(hip.hipFree(dev))

    return host


# just so it doesn't get DCE'd by black/reformat
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
_ = memref

ctx = RAIIMLIRContextModule()
set_container_module(ctx.module)


# just a default attr - actual target is set below (rocdl_attach_target)
@module("kernels", ['#rocdl.target'])
def gpu_module():
    pass


# Everything emitted until ip.__exit__ lands inside the gpu module body.
ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
ip.__enter__()

# tile sizes (columns / rows of the score tile S)
Bc = 32
Br = 32

# batch, heads, sequence length, head dimension
B = 16
nh = 12
N = 128
d = 128

softmax_scale = 1.0 / float(np.sqrt(d))


def softmax(x, axis=None):
    """Numerically stable softmax of ``x`` along ``axis`` (max-shifted)."""
    x_max = np.amax(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)


def manual_attn(q, k, v):
    """NumPy reference: ``softmax(q @ k^T / sqrt(d_k)) @ v`` over the last two axes."""
    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
    att = softmax(att, axis=-1)
    y = att @ v
    return y


rank_reduce = memref.rank_reduce


# Tiled attention kernel; one block per (batch, head), one thread per tile row.
# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
@gpu_func(emit=True)
@canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
def flash_attention(
    Q: T.memref(B, nh, N, d, T.f32()),
    K: T.memref(B, nh, N, d, T.f32()),
    V: T.memref(B, nh, N, d, T.f32()),
    l: T.memref(B, nh, N, T.f32()),
    m: T.memref(B, nh, N, T.f32()),
    O: T.memref(B, nh, N, d, T.f32()),
):
    tx = thread_idx.x
    # batch idx, head_idx
    bx, by = block_idx.x, block_idx.y
    # gpu.printf("bx %ld, by %ld\n", bx, by)

    # Offset into Q,K,V,O,l,m - different for each batch and head
    K = K[bx, by, :, :, rank_reduce]
    V = V[bx, by, :, :, rank_reduce]
    Q = Q[bx, by, :, :, rank_reduce]
    O = O[bx, by, :, :, rank_reduce]
    l = l[bx, by, :, rank_reduce]
    m = m[bx, by, :, rank_reduce]

    # Define SRAM for Q,K,V,S: four consecutive views into the dynamic LDS buffer
    sram = gpu.dynamic_shared_memory()
    Qi = memref.view(sram, (Br, d), dtype=T.f32())
    Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
    Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
    S = memref.view(
        sram,
        (Br, Bc),
        dtype=T.f32(),
        shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
    )

    for bc in scf.range_(0, N, Bc):
        # Load Kj, Vj to SRAM (each thread copies row `tx` of the tile)
        K_ = K[bc : bc + 1, :]
        V_ = V[bc : bc + 1, :]
        for x in scf.range_(0, d):
            Kj[tx, x] = K_[tx, x]
            Vj[tx, x] = V_[tx, x]

        # Every thread reads *all* rows of Kj/Vj below, so the rows written by
        # the other threads must be visible first (the reference flash.cu has a
        # __syncthreads() at this point).
        gpu.barrier()

        for br in scf.range_(0, N, Br):
            # Load Qi to SRAM, l and m to registers
            Q_ = Q[br : br + 1, :]
            for x in scf.range_(0, d):
                Qi[tx, x] = Q_[tx, x]

            l_ = l[br : br + 1]
            m_ = m[br : br + 1]
            row_l_prev = l_[tx]
            row_m_prev = m_[tx]

            # S = QK^T, row_m = rowmax(S)
            row_m: T.f32() = float(np.finfo(np.float32).min)
            for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
                # `dot` (not `sum`) so the builtin is not shadowed
                dot: T.f32() = 0.0
                for x, dot, _ in scf.range_(0, d, iter_args=[dot]):
                    dot += Qi[tx, x] * Kj[y, x]
                    dot = yield dot

                dot *= softmax_scale
                S[tx, y] = dot

                if dot > row_m:
                    row_m_ = yield dot
                else:
                    row_m_ = yield row_m

                row_m = yield row_m_

            # P = exp(S - row_m), row_l = rowsum(P)
            row_l: T.f32() = 0.0
            for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
                S[tx, y] = math.exp(S[tx, y] - row_m)
                row_l += S[tx, y]
                row_l = yield row_l

            # Compute new m and l
            row_m_new = arith.maximumf(row_m_prev, row_m)
            row_l_new = (
                math.exp(row_m_prev - row_m_new) * row_l_prev
                + math.exp(row_m - row_m_new) * row_l
            )
            div = 1.0 / row_l_new
            f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
            f2 = math.exp(row_m - row_m_new)

            # Write O, l, m to HBM
            O_ = O[br : br + 1, :]
            for x in scf.range_(0, d):
                pv: T.f32() = 0.0  # Pij * Vj
                for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
                    pv += S[tx, y] * Vj[y, x]
                    pv = yield pv

                O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)

            l_[tx] = row_l_new
            m_[tx] = row_m_new

        # don't let a fast thread overwrite Kj/Vj while others still read them
        gpu.barrier()


ip.__exit__(None, None, None)

assert gpu_module.operation.verify()
# print(gpu_module)

# NOTE(review): the four views need (Br*d + 2*Bc*d + Br*Bc) floats; 4*Bc*d
# covers that only while Br == Bc <= d - revisit if the tile sizes change.
sram_size = 4 * Bc * d * np.float32().itemsize

# kernel name -> (grid dims, block dims, dynamic shared memory bytes)
launch_params = {
    flash_attention.__name__: (
        (B, nh, 1),
        (Bc, 1, 1),
        sram_size,
    )
}

simplified_module = run_pipeline(
    ctx.module,
    Pipeline()
    .canonicalize()
    .cse()
    .loop_invariant_code_motion()
    .loop_invariant_subset_hoisting()
    .rocdl_attach_target(chip=get_hip_arch(), O=3, abi="500"),
)

assert simplified_module.operation.verify()

# print(simplified_module)
# exit()

lowered_module = run_pipeline(
    simplified_module,
    Pipeline()
    .Gpu(
        Pipeline().convert_gpu_to_rocdl(
            use_bare_ptr_memref_call_conv=True,
            runtime="HIP",
        )
    )
    .gpu_to_llvm()
    .lower_to_llvm()
    .ensure_debug_info_scope_on_llvm_func(emission_kind="Full"),
    # .Nested("llvm.func", Pipeline().sroa()),
)

assert lowered_module.operation.verify()

# print(lowered_module)
# mark every GEP in the lowered module as inbounds
gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp))
for g in gep:
    g.attributes["inbounds"] = UnitAttr.get()

# tell the backend the exact work-group size the kernel is launched with
kernel_funcs = find_ops(
    lowered_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp)
)
for k in kernel_funcs:
    if k.sym_name.value != flash_attention.__name__:
        continue
    _, thread_dims, _ = launch_params[k.sym_name.value]
    # int(): hand the attribute builder a plain Python int, not a NumPy scalar
    k.attributes["rocdl.max_flat_work_group_size"] = IntegerAttr.get(
        T.index(), int(np.prod(thread_dims))
    )

# Without HIP bindings there is nothing to run; the compile-only part is done.
if hip_bindings_not_installed():
    sys.exit()
from hip import hip

output_format = "bin"
# output_format = "llvm"
# output_format = "isa"

lowered_module = run_pipeline(
    lowered_module, Pipeline().gpu_module_to_binary(format=output_format)
)
hsaco = get_compile_object_bytes(lowered_module)
if output_format in {"isa", "llvm", "offloading"}:
    # textual formats are dumped next to this script instead of being executed
    with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f:
        f.write(hsaco)
    sys.exit()


hip_module = hip_check(hip.hipModuleLoadData(hsaco))

stream = 0

# kernel -> accumulated (later averaged) time in ms
times = {
    flash_attention: 0,
}
runs = 32
for kernel in times:
    # The function handle and launch geometry do not change between runs.
    function = hip_check(
        hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
    )
    (
        (
            blocks_per_grid_x,
            blocks_per_grid_y,
            blocks_per_grid_z,
        ),
        (
            threads_per_block_x,
            threads_per_block_y,
            threads_per_block_z,
        ),
        shared_memory,
    ) = launch_params[kernel.__name__]

    for _run in range(runs):
        hip_check(hip.hipDeviceSynchronize())

        # fresh random inputs (and fresh device buffers) for every run
        host, device = init_copy_host_device(B, nh, N, d)
        q_h, k_h, v_h, *_ = host
        correct = manual_attn(q_h, k_h, v_h)

        time_compute = launch_kernel(
            function.as_c_void_p(),
            blocks_per_grid_x,
            blocks_per_grid_y,
            blocks_per_grid_z,
            threads_per_block_x,
            threads_per_block_y,
            threads_per_block_z,
            stream,
            shared_memory,
            *device,
        )

        *_, O_h = copy_device_host(host, device)
        if not np.allclose(correct, O_h):
            with np.printoptions(threshold=np.inf, linewidth=np.inf):
                print(
                    "correct - output:\n",
                    correct.round() - O_h.round(),
                )
            print(f"{kernel.__name__} failed\n")
        else:
            print(f"{kernel.__name__}: {time_compute:.03f}ms")

        times[kernel] += time_compute

for k in times:
    times[k] /= runs

for k, v in times.items():
    print(f"{k.__name__}: {v:.03f}ms")
mlir.extras.types as T\n", + "from mlir.extras.ast.canonicalize import canonicalize\n", + "from mlir.extras.context import mlir_mod_ctx\n", + "from mlir.extras.dialects.arith import constant\n", + "from mlir.extras.dialects.memref import S\n", + "from mlir.extras.dialects.func import func\n", + "from mlir.extras.dialects.scf import canonicalizer as scf, range_\n", + "from mlir.extras.runtime.passes import Pipeline, run_pipeline\n", + "from mlir.extras.runtime.refbackend import LLVMJITBackend\n", + "from mlir.ir import StridedLayoutAttr\n", + "\n", + "# you need this to register the memref value caster\n", + "# noinspection PyUnresolvedReferences\n", + "import mlir.extras.dialects.memref\n", + "\n", + "ctx_man = mlir_mod_ctx()\n", + "ctx = ctx_man.__enter__()\n", + "backend = LLVMJITBackend()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9ijSKRNSOQ9D" + }, + "source": [ + "# MWE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2bJ1wqMPNshV" + }, + "outputs": [], + "source": [ + "K = 10\n", + "memref_i64 = T.memref(K, K, T.i64())\n", + "\n", + "@func(emit=True)\n", + "@canonicalize(using=scf)\n", + "def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):\n", + " one = constant(1)\n", + " two = constant(2)\n", + " if one > two:\n", + " C[0, 0] = constant(3, T.i64())\n", + " else:\n", + " for i in range_(0, K):\n", + " for j in range_(0, K):\n", + " C[i, j] = A[i, j] * B[i, j]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xJDEigigOY09" + }, + "source": [ + "## `func`, `memref`, `scf`, and `arith` dialects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zBDx-j9RN3XX", + "outputId": "913b8bec-270b-4db0-f78e-650327678524" + }, + "outputs": [], + "source": [ + "run_pipeline(ctx.module, Pipeline().cse())\n", + "print(ctx.module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"_P-E1f2aOm6y" + }, + "source": [ + "## Lower to `llvm` dialect" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dlbMF12mN5N0", + "outputId": "90262ae0-77d8-4ee1-d436-d1209c24ec85" + }, + "outputs": [], + "source": [ + "module = backend.compile(\n", + " ctx.module,\n", + " kernel_name=memfoo.__name__,\n", + " pipeline=Pipeline().bufferize().lower_to_llvm(),\n", + ")\n", + "print(module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Dc-HjIzhO6a9" + }, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZKTUiSksN8vM" + }, + "outputs": [], + "source": [ + "A = np.random.randint(0, 10, (K, K)).astype(np.int64)\n", + "B = np.random.randint(0, 10, (K, K)).astype(np.int64)\n", + "C = np.zeros((K, K), dtype=np.int64)\n", + "backend.load(module).memfoo(A, B, C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TtdESiwEPDjt" + }, + "source": [ + "## Check the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3gj7xvY9OCpB", + "outputId": "eadbf595-8bad-4246-d265-56d0051ffa85" + }, + "outputs": [], + "source": [ + "print(C)\n", + "assert np.array_equal(A * B, C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ywu5wuvxUVe-" + }, + "source": [ + "## Clean up after yourself" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lB_vHZcvUYVI" + }, + "outputs": [], + "source": [ + "ctx_man.__exit__(None, None, None);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UP9QlXHBQwEn" + }, + "source": [ + "# Slightly more complicated example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "amh_lbcZQzj6" + }, + "outputs": [], + "source": [ + "ctx_man = mlir_mod_ctx()\n", + "ctx = 
ctx_man.__enter__()\n", + "\n", + "K = 256\n", + "D = 32\n", + "\n", + "F = K // D\n", + "ranked_memref_kxk_f32 = T.memref(K, K, T.f32())\n", + "layout = StridedLayoutAttr.get(S, (K, 1))\n", + "ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n", + "\n", + "@func(emit=True)\n", + "@canonicalize(using=scf)\n", + "def tile(\n", + " A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32\n", + "):\n", + " for i in range_(0, D):\n", + " for j in range_(0, D):\n", + " C[i, j] = A[i, j] + B[i, j]\n", + "\n", + "@func(emit=True)\n", + "@canonicalize(using=scf)\n", + "def tiled_memfoo(\n", + " A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n", + "):\n", + " for i in range_(0, F):\n", + " for j in range_(0, F):\n", + " l = lambda l: l * D\n", + " r = lambda r: (r + 1) * D\n", + " a, b, c = (\n", + " A[l(i) : r(i), l(j) : r(j)],\n", + " B[l(i) : r(i), l(j) : r(j)],\n", + " C[l(i) : r(i), l(j) : r(j)],\n", + " )\n", + " tile(a, b, c)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yeplm5V6RoHC" + }, + "source": [ + "## `func`, `memref`, `scf`, and `arith` dialects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fkR6mN8ZRb6i", + "outputId": "3597d048-f24e-4cc5-b332-ad79f0e7bd18" + }, + "outputs": [], + "source": [ + "print(ctx.module)\n", + "module = run_pipeline(ctx.module, str(Pipeline().cse()))\n", + "print(module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6ObNEVeTR0dF" + }, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OahzCa2yR3AX" + }, + "outputs": [], + "source": [ + "module = backend.compile(\n", + " module,\n", + " kernel_name=tiled_memfoo.__name__,\n", + " pipeline=Pipeline().bufferize().lower_to_llvm(),\n", + ")\n", + "\n", + "A = np.random.randint(0, 10, (K, K)).astype(np.float32)\n", + "B = 
np.random.randint(0, 10, (K, K)).astype(np.float32)\n", + "C = np.zeros((K, K)).astype(np.float32)\n", + "\n", + "backend.load(module).tiled_memfoo(A, B, C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "374nuYgWTSJL" + }, + "source": [ + "## Check your results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2biEtkdFTT_H", + "outputId": "47ba52c0-3d62-4319-a6b0-f23f3fd468e0" + }, + "outputs": [], + "source": [ + "print(C)\n", + "assert np.array_equal(A + B, C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dkK6RuhUUfi6" + }, + "source": [ + "## Clean up after yourself" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zXH2qabvUhOR" + }, + "outputs": [], + "source": [ + "ctx_man.__exit__(None, None, None);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UfdO_xJDTzh-" + }, + "source": [ + "# Do it like the professionals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7Dwvs8CBT2T9", + "outputId": "cfc10029-cd9d-45da-d043-a4ee726133ed" + }, + "outputs": [], + "source": [ + "ctx_man = mlir_mod_ctx()\n", + "ctx = ctx_man.__enter__()\n", + "\n", + "ranked_memref_kxk_f32 = T.memref(K, K, T.f32())\n", + "layout = StridedLayoutAttr.get(S, (K, 1))\n", + "ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n", + "\n", + "from mlir.extras.dialects import linalg\n", + "\n", + "@func(emit=True)\n", + "@canonicalize(using=scf)\n", + "def linalg_memfoo(\n", + " A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n", + "):\n", + " for i in range_(0, F):\n", + " for j in range_(0, F):\n", + " l = lambda l: l * D\n", + " r = lambda r: (r + 1) * D\n", + " a, b, c = (\n", + " A[l(i) : r(i), l(j) : r(j)],\n", + " B[l(i) : r(i), l(j) : r(j)],\n", + " C[l(i) : r(i), 
l(j) : r(j)],\n", + " )\n", + " linalg.add(a, b, c)\n", + "\n", + "module = run_pipeline(ctx.module, str(Pipeline().cse()))\n", + "print(module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AdUDJvlMVHNk" + }, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_2DLkkQXVD_6" + }, + "outputs": [], + "source": [ + "module = backend.compile(\n", + " module,\n", + " kernel_name=linalg_memfoo.__name__,\n", + " pipeline=Pipeline().convert_linalg_to_loops().bufferize().lower_to_llvm()\n", + ")\n", + "invoker = backend.load(module)\n", + "A = np.random.randint(0, 10, (K, K)).astype(np.float32)\n", + "B = np.random.randint(0, 10, (K, K)).astype(np.float32)\n", + "C = np.zeros((K, K)).astype(np.float32)\n", + "\n", + "backend.load(module).linalg_memfoo(A, B, C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hupr7s5LVVpQ" + }, + "source": [ + "## Check your results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-22vHoGXVXfm", + "outputId": "838de41e-8670-45bf-8ee7-65d6e9b8eb1a" + }, + "outputs": [], + "source": [ + "print(C)\n", + "assert np.array_equal(A + B, C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IZKkHZb2PKIB" + }, + "source": [ + "## Clean up after yourself" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XUvvimM-PHPq" + }, + "outputs": [], + "source": [ + "ctx_man.__exit__(None, None, None);" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": 
"3.12.0rc3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/projects/eudsl-python-extras/examples/mwe.py b/projects/eudsl-python-extras/examples/mwe.py new file mode 100644 index 00000000..5d056fcd --- /dev/null +++ b/projects/eudsl-python-extras/examples/mwe.py @@ -0,0 +1,166 @@ +import platform + +import numpy as np + +import mlir.extras.types as T +from mlir.dialects import builtin +from mlir.dialects.transform import any_op_t +from mlir.dialects.transform.extras import named_sequence, apply_patterns +from mlir.extras.util import find_ops +from mlir.ir import StringAttr, UnitAttr + +# you need this to register the memref value caster +# noinspection PyUnresolvedReferences +import mlir.extras.dialects.memref +from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule +from mlir.dialects.bufferization import LayoutMapOption +from mlir.dialects.transform.vector import ( + VectorContractLowering, + VectorMultiReductionLowering, + VectorTransferSplit, + VectorTransposeLowering, +) +from mlir.extras.dialects import linalg +from mlir.extras.dialects.func import func +from mlir.extras.dialects.transform import ( + match, + tile_to_scf_for, + get_parent_op, + transform_any_op_t, +) +from mlir.extras.dialects import transform +from mlir.extras.runtime.passes import Pipeline, run_pipeline +from mlir.extras.runtime.refbackend import LLVMJITBackend + +ctx = RAIIMLIRContext() +backend = LLVMJITBackend() +module = ExplicitlyManagedModule() + +M, K, N = 2, 4, 6 + + +@func +def matmul_tensors( + A: T.tensor(M, K, T.f32()), + B: T.tensor(K, N, T.f32()), + C: T.tensor(M, N, T.f32()), +): + return linalg.matmul(A, B, C) + + +@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")}) +def payload(): + matmul_tensors.emit(force=True) + + +@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()}) +def mod_transform(): + @named_sequence("main", [any_op_t()], []) + def main(module_op: any_op_t()): + matmul = match(module_op, 
# Transform-dialect script: tile the matmul, vectorize it, bufferize the
# payload and lower the resulting vector ops.
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
def mod_transform():
    @named_sequence("main", [any_op_t()], [])
    def main(module_op: any_op_t()):
        matmul = match(module_op, ops=["linalg.matmul"])
        tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])
        # vectorize the isolated-from-above parent of the tiled matmul
        transform.structured.vectorize_children_and_apply_patterns(
            get_parent_op(transform_any_op_t(), tiled_matmul, isolated_from_above=True)
        )
        new_mod = transform.bufferization.one_shot_bufferize(
            module_op,
            function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
            bufferize_function_boundaries=True,
        )

        func_op = match(new_mod, ops=["func.func"])

        # vector lowering patterns applied to every matched func.func
        @apply_patterns(func_op)
        def pats():
            transform.apply_patterns.vector.lower_contraction(
                lowering_strategy=VectorContractLowering.OuterProduct
            )
            transform.apply_patterns.vector.transfer_permutation_patterns()
            transform.apply_patterns.vector.lower_multi_reduction(
                lowering_strategy=VectorMultiReductionLowering.InnerParallel
            )
            transform.apply_patterns.vector.split_transfer_full_partial(
                split_transfer_strategy=VectorTransferSplit.LinalgCopy
            )
            transform.apply_patterns.vector.transfer_to_scf(
                max_transfer_rank=1, full_unroll=True
            )
            transform.apply_patterns.vector.lower_transfer(max_transfer_rank=1)
            transform.apply_patterns.vector.lower_shape_cast()
            transform.apply_patterns.vector.lower_transpose(
                lowering_strategy=VectorTransposeLowering.Shuffle1D
            )


# Close the explicitly managed module; `module` now holds the finished module.
module = module.finish()
# print(module)

# Run the "main" named sequence against the module tagged "payload".
vectorized_module = run_pipeline(
    module,
    pipeline=Pipeline().transform_interpreter(
        entry_point="main", debug_payload_root_tag="payload"
    ),
)

# print(vectorized_module)

# https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
lower_to_llvm = (
    Pipeline()
    .Func(
        Pipeline()
        # Blanket-convert any remaining high-level vector ops to loops if any remain.
        .convert_vector_to_scf()
        # Blanket-convert any remaining linalg ops to loops if any remain.
        .convert_linalg_to_loops()
    )
    # Blanket-convert any remaining affine ops if any remain.
    .lower_affine()
    # Convert SCF to CF (always needed).
    .convert_scf_to_cf()
    # Sprinkle some cleanups.
    .canonicalize()
    .cse()
    # Convert vector to LLVM (always needed).
    .convert_vector_to_llvm()
    # Convert Math to LLVM (always needed).
    .Func(Pipeline().convert_math_to_llvm())
    # Expand complicated MemRef operations before lowering them.
    .expand_strided_metadata()
    # The expansion may create affine expressions. Get rid of them.
    .lower_affine()
    # Convert MemRef to LLVM (always needed).
    .finalize_memref_to_llvm()
    # Convert Func to LLVM (always needed).
    .convert_func_to_llvm()
    .convert_arith_to_llvm()
    .convert_cf_to_llvm()
    # Convert Index to LLVM (always needed).
    .convert_index_to_llvm()
    # Convert remaining unrealized_casts (always needed).
    .reconcile_unrealized_casts()
)


# Compile only the nested module carrying the "payload" tag, not the
# transform script that sits next to it.
compiled_module = backend.compile(
    find_ops(
        vectorized_module.operation,
        lambda x: "transform.target_tag" in x.attributes
        and x.attributes["transform.target_tag"].value == "payload",
        single=True,
    ),
    kernel_name=matmul_tensors.__name__,
    pipeline=lower_to_llvm,
)

# print(compiled_module)

A = np.random.randint(0, 10, (M, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, N)).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

# On emscripten the script stops after compilation and never loads/runs the
# module.  `raise SystemExit` is what `exit()` does, without depending on the
# interactive helper that the `site` module injects.
if platform.system().lower() == "emscripten":
    raise SystemExit

backend.load(compiled_module).matmul_tensors_capi_wrapper(A, B, C)
assert np.allclose(A @ B, C)
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr, Attribute
import mlir.extras.types as T

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.gpu import (
    all_reduce,
    wait,
    thread_attr as thread,
    block_idx,
    thread_idx,
    block_dim,
    GPUModuleMeta,
    func as gpu_func,
    set_container_module,
    launch,
    all_reduce_,
    module,
    get_compile_object_bytes,
    lds_space,
)
from mlir.extras.runtime.passes import run_pipeline, Pipeline
from mlir.extras.util import find_ops

# noinspection PyUnresolvedReferences
from util import (
    hip_check,
    launch_kernel,
    hip_synchronize,
    hip_bindings_not_installed,
    get_hip_arch,
)


def time_to_gflops(time_ms, N):
    """Convert a runtime in milliseconds into GFLOP/s, rounded down.

    Counts ``2*N**3 + 3*N**2`` floating point operations; the floor division
    truncates the result to a whole number (still returned as a float).
    """
    return 1e-6 * (N * N * N * 2 + 3 * N * N) // time_ms


# just so it doesn't get DCE'd by black/reformat
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
_ = memref

ctx = RAIIMLIRContextModule()
set_container_module(ctx.module)


# just a default attr - actual target is set below
@module("kernels", ['#rocdl.target'])
def gpu_module():
    pass


# Everything emitted until the matching ip.__exit__ lands inside the gpu module.
ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
ip.__enter__()

# C (M x N) = A (M x K) @ B (K x N)
M, K, N = 1024, 1024, 1024


# One thread per output element, straight from global memory.
@gpu_func(emit=True)
@canonicalize(using=scf.canonicalizer)
def kernel1_naive(
    A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32())
):
    row = block_idx.y * block_dim.y + thread_idx.y
    col = block_idx.x * block_dim.x + thread_idx.x
    if (arith.index_cast(row, to=T.i32()) < M) & (
        arith.index_cast(col, to=T.i32()) < N
    ):
        acc = arith.constant(0.0)
        for k, acc, _ in scf.range_(K, iter_args=[acc]):
            a = A[row, k]
            b = B[k, col]
            acc = llvm.intr_fmuladd(a, b, acc)
            acc = yield acc

        C[row, col] = acc


# kernel name -> (grid dims, block dims, dynamic shared memory bytes)
# grid x walks columns (N), grid y walks rows (M)
launch_params = {
    kernel1_naive.__name__: (
        (N // 16, M // 16, 1),
        (16, 16, 1),
        0,
    )
}

BN = BK = TILE_SIZE = 32

# statically sized LDS tiles used by the kernel2_lds_shared_direct_* kernels
A_shared = memref.global_(
    sym_name="A_shared_BN_BK_0",
    type=T.memref(BN, BK, T.f32(), memory_space=lds_space()),
    alignment=16,
)
B_shared = memref.global_(
    sym_name="B_shared_BK_BN_0",
    type=T.memref(BK, BN, T.f32(), memory_space=lds_space()),
    alignment=16,
)

dtype = T.f32()


# Tiled matmul staging A/B tiles through the two LDS globals above.
@gpu_func(emit=True)
@canonicalize(using=scf.canonicalizer)
def kernel2_lds_shared_direct_load_globals(
    A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32())
):
    As = memref.get_global(A_shared)
    Bs = memref.get_global(B_shared)

    row = block_idx.y * TILE_SIZE + thread_idx.y
    col = block_idx.x * TILE_SIZE + thread_idx.x

    acc = arith.constant(0.0)

    # `t` indexes rows of B and columns of A, i.e. the reduction dimension K
    for t, acc, _ in scf.range_(0, K, BK, iter_args=[acc]):
        Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col]
        As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]

        # tiles must be fully written before any thread reads them
        gpu.barrier()

        # plain Python range: unrolled while the kernel is being traced
        for k in range(BK):
            a = As[thread_idx.y, k]
            b = Bs[k, thread_idx.x]
            acc = llvm.intr_fmuladd(a, b, acc)

        # and fully consumed before the next iteration overwrites them
        gpu.barrier()

        acc = yield acc

    C[row, col] = acc


launch_params[kernel2_lds_shared_direct_load_globals.__name__] = (
    (N // TILE_SIZE, M // TILE_SIZE, 1),
    (TILE_SIZE, TILE_SIZE, 1),
    0,
)


# NOTE(review): despite the name and the dynamic shared-memory size passed at
# launch, this body reads the same static LDS globals as
# kernel2_lds_shared_direct_load_globals - confirm whether it was meant to use
# gpu.dynamic_shared_memory().
@gpu_func(emit=True)
@canonicalize(using=scf.canonicalizer)
def kernel2_lds_shared_direct_dynamic(
    A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32())
):
    As = memref.get_global(A_shared)
    Bs = memref.get_global(B_shared)

    row = block_idx.y * TILE_SIZE + thread_idx.y
    col = block_idx.x * TILE_SIZE + thread_idx.x

    acc = arith.constant(0.0)

    # `t` indexes rows of B and columns of A, i.e. the reduction dimension K
    for t, acc, _ in scf.range_(0, K, BK, iter_args=[acc]):
        Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col]
        As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]

        gpu.barrier()

        for k in range(BK):
            a = As[thread_idx.y, k]
            b = Bs[k, thread_idx.x]
            acc = llvm.intr_fmuladd(a, b, acc)

        gpu.barrier()

        acc = yield acc

    C[row, col] = acc


launch_params[kernel2_lds_shared_direct_dynamic.__name__] = (
    (N // TILE_SIZE, M // TILE_SIZE, 1),
    (TILE_SIZE, TILE_SIZE, 1),
    2 * TILE_SIZE * TILE_SIZE * T.f32().width // 8,
)


# Same tiling, but the tiles are views into dynamic shared memory and the
# global operands are addressed through per-block subviews.
@gpu_func(emit=True)
@canonicalize(using=scf.canonicalizer)
def kernel2_lds_shared_subview(
    A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32())
):
    # allocate buffer for current block in fast shared mem
    # shared mem is shared between all threads in a block
    base = gpu.dynamic_shared_memory()
    As = memref.view(base, (TILE_SIZE, TILE_SIZE), dtype=dtype)
    Bs = memref.view(
        base, (TILE_SIZE, TILE_SIZE), dtype=dtype, shift=TILE_SIZE * TILE_SIZE
    )

    # the inner row & col that we're accessing in this thread
    tid = gpu.thread_id()
    thread_row = tid / TILE_SIZE
    thread_col = tid % TILE_SIZE

    # the output block that we want to compute in this threadblock
    c_row = block_idx.x * TILE_SIZE
    c_col = block_idx.y * TILE_SIZE

    tmp = arith.constant(0, type=dtype)

    for bk_idx, tmp, _ in scf.range_(0, K, TILE_SIZE, iter_args=[tmp]):
        A_ = A[c_row : c_row + TILE_SIZE, bk_idx : bk_idx + TILE_SIZE]
        B_ = B[bk_idx : bk_idx + TILE_SIZE, c_col : c_col + TILE_SIZE]

        # Have each thread load one of the elements in A & B
        # Make the threadCol (=thread_idx.x) the consecutive index
        # to allow global memory access coalescing
        As[thread_row, thread_col] = A_[thread_row, thread_col]
        Bs[thread_row, thread_col] = B_[thread_row, thread_col]

        # block threads in this block until cache is fully populated
        gpu.barrier()

        # execute the dotproduct on the currently cached block
        for dot_idx in range(TILE_SIZE):
            a, b = As[thread_row, dot_idx], Bs[dot_idx, thread_col]
            tmp = llvm.intr_fmuladd(a, b, tmp)

        # need to sync again at the end, to avoid faster threads
        # fetching the next block into the cache before slower threads are done
        gpu.barrier()

        tmp = yield tmp

    C_ = C[c_row : c_row + TILE_SIZE, c_col : c_col + TILE_SIZE]
    C_[thread_row, thread_col] = tmp


# here grid x walks rows (c_row = block_idx.x * TILE_SIZE), grid y walks columns
launch_params[kernel2_lds_shared_subview.__name__] = (
    (M // TILE_SIZE, N // TILE_SIZE, 1),
    (TILE_SIZE, TILE_SIZE, 1),
    2 * TILE_SIZE * TILE_SIZE * T.f32().width // 8,
)

BLOCK_SIZE = 256
# Block Tile size
BN = 128
BM = 128
# Number of Row or column we read per batch
BK = 8

# LDS tiles for the register-blocked kernels; A is stored transposed (BK x BM)
A_shared = memref.global_(
    sym_name="A_shared_BK_BM_1",
    type=T.memref(BK, BM, T.f32(), memory_space=lds_space()),
    alignment=16,
)
B_shared = memref.global_(
    sym_name="B_shared_BK_BN_1",
    type=T.memref(BK, BN, T.f32(), memory_space=lds_space()),
    alignment=16,
)


# Register-blocked matmul: each thread accumulates several TM x TN sub-tiles
# of C in a private buffer (c_regs).
@gpu_func(emit=True)
@canonicalize(using=scf.canonicalizer)
def kernel3_registers(
    A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32())
):
    # Block Tile size
    BN = 128
    BM = 128
    # Number of Row or column we read per batch
    BK = 8

    # Thread Tile size
    TN = 4
    TM = 4

    nbWaves = BLOCK_SIZE // 32
    # Wave Tile size
    WN = 64
    WM = BN * BM // nbWaves // WN

    # Number of wave on X & Y axis in the Block tile
    nbWaveX = BN // WN
    nbWaveY = BM // WM

    waveIndex = thread_idx.x // 32
    waveIdx = waveIndex % nbWaveX
    waveIdy = waveIndex // nbWaveX
    indexInWave = thread_idx.x % 32

    # A wave is a block of 8x4 of the output matrix
    nbThreadXPerWave = 8
    nbThreadYPerWave = 4

    # Thread coordinates in Wave
    idxInWave = indexInWave % nbThreadXPerWave
    idyInWave = indexInWave // nbThreadXPerWave

    nbIterWaveN = WN // (nbThreadXPerWave * TN)
    nbIterWaveM = WM // (nbThreadYPerWave * TM)

    # Wave Sub-tile size
    SUBWN = WN // nbIterWaveN
    SUBWM = WM // nbIterWaveM

    # Thread mapping to read the BM x BK tile of A
    rAIdx = thread_idx.x % BK
    rAIdy = thread_idx.x // BK
    # Thread mapping to read the BK x BN tile of B
    rBIdx = thread_idx.x % BN
    rBIdy = thread_idx.x // BN

    strideReadB = BLOCK_SIZE // BN
    strideReadA = BLOCK_SIZE // BK
    nbReadsB = BN * BK // BLOCK_SIZE
    nbReadsA = BM * BK // BLOCK_SIZE

    A_col = memref.alloca([nbIterWaveM * TM], T.f32())
    B_row = memref.alloca([nbIterWaveN * TN], T.f32())

    As = memref.get_global(A_shared)
    Bs = memref.get_global(B_shared)

    l = TM * nbIterWaveM * TN * nbIterWaveN
    c_regs = memref.alloca([l], T.f32())

    # zero the accumulators with a memset over l * 4 bytes (l f32 values)
    c_regs_idx = memref.extract_aligned_pointer_as_index(c_regs)
    c_regs_i64 = arith.index_cast(c_regs_idx, T.i64())
    c_regs_ptr = llvm.inttoptr(llvm.llvm_ptr_t(), c_regs_i64)

    l_4 = llvm.mlir_constant(l * 4)
    c_0 = llvm.mlir_constant(0, T.i8())
    llvm.intr_memset(c_regs_ptr, c_0, l_4, False)

    # kId steps through the reduction dimension (rows of B / columns of A)
    for kId in scf.range_(0, K, BK):

        for i in range(nbReadsB):
            index_x = BN * block_idx.x + rBIdx
            index_y = rBIdy + i * strideReadB + kId
            Bs[index_y % BK, index_x % BN] = B[index_y, index_x]

        for i in range(nbReadsA):
            index_x = rAIdx + kId
            index_y = BM * block_idx.y + rAIdy + i * strideReadA
            # transposed store: As is indexed [k, m]
            As[index_x % BK, index_y % BM] = A[index_y, index_x]

        gpu.barrier()

        for k in range(BK):
            for iterWave in range(nbIterWaveN):
                for i in range(TN):
                    index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
                    B_row[iterWave * TN + i] = Bs[k, index]

            for iterWave in range(nbIterWaveM):
                for i in range(TM):
                    index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
                    A_col[iterWave * TM + i] = As[k, index]

            for iterWaveM in range(nbIterWaveM):
                for iterWaveN in range(nbIterWaveN):
                    for yt in range(TM):
                        for xt in range(TN):
                            x = iterWaveN * TN + xt
                            y = iterWaveM * TM + yt
                            a = A_col[y]
                            b = B_row[x]
                            c = c_regs[y * TN * nbIterWaveN + x]
                            c = llvm.intr_fmuladd(a, b, c)
                            # c = llvm.intr_fma(a, b, c)
                            c_regs[y * TN * nbIterWaveN + x] = c
                            # c_regs[y * TN * nbIterWaveN + x] += a * b

        gpu.barrier()

    # write the accumulated sub-tiles back to C
    for iterWaveM in range(nbIterWaveM):
        for iterWaveN in range(nbIterWaveN):
            xOut = block_idx.x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
            yOut = block_idx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
            for yt in range(TM):
                for xt in range(TN):
                    C[yOut + yt, xOut + xt] = c_regs[
                        TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)
                    ]
+launch_params[kernel3_registers.__name__] = ( + (N // 128, N // 128, 1), + (BLOCK_SIZE, 1, 1), + 0, +) + + +@gpu_func(emit=True) +@canonicalize(using=scf.canonicalizer) +def kernel4_gmem_db( + A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32()) +): + # Thread Tile size + TN = 4 + TM = 4 + + nbWaves = BLOCK_SIZE // 32 + # Wave Tile size + WN = 64 + WM = BN * BM // nbWaves // WN + + # Number of wave on X & Y axis in the Block tile + nbWaveX = BN // WN + + # A wave is a block of 8x4 of the output matrix + nbThreadXPerWave = 8 + nbThreadYPerWave = 4 + + nbIterWaveN = WN // (nbThreadXPerWave * TN) + nbIterWaveM = WM // (nbThreadYPerWave * TM) + + # Wave Sub-tile size + SUBWN = WN // nbIterWaveN + SUBWM = WM // nbIterWaveM + + strideReadB = BLOCK_SIZE // BN + strideReadA = BLOCK_SIZE // BK + nbReadsB = BN * BK // BLOCK_SIZE + nbReadsA = BM * BK // BLOCK_SIZE + + waveIndex = thread_idx.x // 32 + waveIdx = waveIndex % nbWaveX + waveIdy = waveIndex // nbWaveX + indexInWave = thread_idx.x % 32 + + # Thread coordinates in Wave + idxInWave = indexInWave % nbThreadXPerWave + idyInWave = indexInWave / nbThreadXPerWave + + # Thread mapping to read BKxBN block from A + rAIdx = thread_idx.x % BK + rAIdy = thread_idx.x // BK + # Thread mapping to read BNxBK block from B + rBIdx = thread_idx.x % BN + rBIdy = thread_idx.x // BN + + A_col = memref.alloca([nbIterWaveM * TM], T.f32()) + B_row = memref.alloca([nbIterWaveN * TN], T.f32()) + + As = memref.get_global(A_shared) + Bs = memref.get_global(B_shared) + + l = TM * nbIterWaveM * TN * nbIterWaveN + c_regs = memref.alloca([l], T.f32()) + + c_regs_idx = memref.extract_aligned_pointer_as_index(c_regs) + c_regs_i64 = arith.index_cast(c_regs_idx, T.i64()) + c_regs_ptr = llvm.inttoptr(llvm.llvm_ptr_t(), c_regs_i64) + + l_4 = llvm.mlir_constant(l * 4) + c_0 = llvm.mlir_constant(0, T.i8()) + llvm.intr_memset(c_regs_ptr, c_0, l_4, False) + + for i in range(nbReadsB): + index_x = BN * block_idx.x + rBIdx + 
index_y = rBIdy + i * strideReadB + Bs[index_y % BK, index_x % BN] = B[index_y, index_x] + + for i in range(nbReadsA): + index_x = rAIdx + index_y = BM * block_idx.y + rAIdy + i * strideReadA + As[index_x % BK, index_y % BM] = A[index_y, index_x] + + gpu.barrier() + + regA = memref.alloca([nbReadsA], T.f32()) + regB = memref.alloca([nbReadsB], T.f32()) + + N_minus_BK = arith.constant(N - BK, index=True) + + for kId in scf.range_(0, N, BK): + + pred = index_dialect.cmp(index_dialect.IndexCmpPredicate.SLT, kId, N_minus_BK) + if pred: + + for i in range(nbReadsB): + index_x = BN * block_idx.x + rBIdx + index_y = rBIdy + i * strideReadB + kId + BK + regB[i] = B[index_y, index_x] + + for i in range(nbReadsA): + index_x = rAIdx + kId + BK + index_y = BM * block_idx.y + rAIdy + i * strideReadA + regA[i] = A[index_y, index_x] + + for k in range(BK): + + for iterWave in range(nbIterWaveN): + for i in range(TN): + index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i + B_row[iterWave * TN + i] = Bs[k, index] + + for iterWave in range(nbIterWaveM): + for i in range(TM): + index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i + A_col[iterWave * TM + i] = As[k, index] + + for iterWaveM in range(nbIterWaveM): + for iterWaveN in range(nbIterWaveN): + for yt in range(TM): + for xt in range(TN): + x = iterWaveN * TN + xt + y = iterWaveM * TM + yt + a = A_col[y] + b = B_row[x] + c = c_regs[y * TN * nbIterWaveN + x] + c = llvm.intr_fmuladd(a, b, c) + c_regs[y * TN * nbIterWaveN + x] = c + + gpu.barrier() + + if pred: + + for i in range(nbReadsB): + index_x = BN * block_idx.x + rBIdx + index_y = rBIdy + i * strideReadB + kId + BK + Bs[index_y % BK, index_x % BN] = regB[i] + + for i in range(nbReadsA): + index_x = rAIdx + kId + BK + index_y = BM * block_idx.y + rAIdy + i * strideReadA + As[index_x % BK, index_y % BM] = regA[i] + + gpu.barrier() + + for iterWaveM in range(nbIterWaveM): + for iterWaveN in range(nbIterWaveN): + xOut = block_idx.x * BN + waveIdx * WN + 
iterWaveN * SUBWN + TN * idxInWave + yOut = block_idx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave + for yt in range(TM): + for xt in range(TN): + C[yOut + yt, xOut + xt] = c_regs[ + TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt) + ] + + +launch_params[kernel4_gmem_db.__name__] = ( + (N // 128, N // 128, 1), + (BLOCK_SIZE, 1, 1), + 0, +) + +A_shared = memref.global_( + sym_name="A_shared_BK_BM_times_4", + type=T.memref(BK, BM * 4, T.f32(), memory_space=lds_space()), + alignment=16, +) + + +@gpu_func(emit=True) +@canonicalize(using=scf.canonicalizer) +def kernel5_lds_optim( + A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), C: T.memref(M, N, T.f32()) +): + # Thread Tile size + TN = 4 + TM = 4 + + nbWaves = BLOCK_SIZE // 32 + # Wave Tile size + WN = 64 + WM = BN * BM // nbWaves // WN + + # Number of wave on X & Y axis in the Block tile + nbWaveX = BN // WN + + # A wave is a block of 8x4 of the output matrix + nbThreadXPerWave = 8 + nbThreadYPerWave = 4 + + nbIterWaveN = WN // (nbThreadXPerWave * TN) + nbIterWaveM = WM // (nbThreadYPerWave * TM) + + # Wave Sub-tile size + SUBWN = WN // nbIterWaveN + SUBWM = WM // nbIterWaveM + + strideReadB = BLOCK_SIZE // BN + strideReadA = BLOCK_SIZE // BK + nbReadsB = BN * BK // BLOCK_SIZE + nbReadsA = BM * BK // BLOCK_SIZE + + waveIndex = thread_idx.x // 32 + waveIdx = waveIndex % nbWaveX + waveIdy = waveIndex // nbWaveX + indexInWave = thread_idx.x % 32 + + # Thread coordinates in Wave + idxInWave = indexInWave % nbThreadXPerWave + idyInWave = indexInWave / nbThreadXPerWave + + # Thread mapping to read BKxBN block from A + rAIdx = thread_idx.x % BK + rAIdy = thread_idx.x // BK + # Thread mapping to read BNxBK block from B + rBIdx = thread_idx.x % BN + rBIdy = thread_idx.x // BN + + A_col = memref.alloca([nbIterWaveM * TM], T.f32()) + B_row = memref.alloca([nbIterWaveN * TN], T.f32()) + + As = memref.get_global(A_shared) + Bs = memref.get_global(B_shared) + + l = TM * nbIterWaveM * TN * 
nbIterWaveN + c_regs = memref.alloca([l], T.f32()) + + c_regs_idx = memref.extract_aligned_pointer_as_index(c_regs) + c_regs_i64 = arith.index_cast(c_regs_idx, T.i64()) + c_regs_ptr = llvm.inttoptr(llvm.llvm_ptr_t(), c_regs_i64) + + l_4 = llvm.mlir_constant(l * 4) + c_0 = llvm.mlir_constant(0, T.i8()) + llvm.intr_memset(c_regs_ptr, c_0, l_4, False) + + for i in range(nbReadsB): + index_x = BN * block_idx.x + rBIdx + index_y = rBIdy + i * strideReadB + Bs[index_y % BK, index_x % BN] = B[index_y, index_x] + + for i in range(nbReadsA): + index_x = rAIdx + index_y = BM * block_idx.y + rAIdy + i * strideReadA + As[index_x % BK, index_y % BM] = A[index_y, index_x] + + gpu.barrier() + + regA = memref.alloca([nbReadsA], T.f32()) + regB = memref.alloca([nbReadsB], T.f32()) + + for kId in scf.range_(0, N, BK): + + kId_i32 = arith.index_cast(kId, to=T.i32()) + if kId_i32 < (N - BK): + + for i in range(nbReadsB): + index_x = BN * block_idx.x + rBIdx + index_y = rBIdy + i * strideReadB + kId + BK + regB[i] = B[index_y, index_x] + + for i in range(nbReadsA): + index_x = rAIdx + kId + BK + index_y = BM * block_idx.y + rAIdy + i * strideReadA + regA[i] = A[index_y, index_x] + + for k in range(BK): + + for iterWave in range(nbIterWaveN): + for i in range(TN): + index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i + B_row[iterWave * TN + i] = Bs[k, index] + + for iterWave in range(nbIterWaveM): + for i in range(TM): + index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i + A_col[iterWave * TM + i] = As[k, index] + + for iterWaveM in range(nbIterWaveM): + for iterWaveN in range(nbIterWaveN): + for yt in range(TM): + for xt in range(TN): + x = iterWaveN * TN + xt + y = iterWaveM * TM + yt + a = A_col[y] + b = B_row[x] + c = c_regs[y * TN * nbIterWaveN + x] + c = llvm.intr_fmuladd(a, b, c) + c_regs[y * TN * nbIterWaveN + x] = c + + gpu.barrier() + + if kId_i32 < (N - BK): + + for i in range(nbReadsB): + index_x = BN * block_idx.x + rBIdx + index_y = rBIdy + i * 
strideReadB + kId + BK + Bs[index_y % BK, index_x % BN] = regB[i] + + for i in range(nbReadsA): + index_x = rAIdx + kId + BK + index_y = BM * block_idx.y + rAIdy + i * strideReadA + As[index_x % BK, index_y % BM] = regA[i] + + gpu.barrier() + + for iterWaveM in range(nbIterWaveM): + for iterWaveN in range(nbIterWaveN): + xOut = block_idx.x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave + yOut = block_idx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave + for yt in range(TM): + for xt in range(TN): + C[yOut + yt, xOut + xt] = c_regs[ + TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt) + ] + + +launch_params[kernel5_lds_optim.__name__] = ( + (N // 128, N // 128, 1), + (BLOCK_SIZE, 1, 1), + 0, +) + + +ip.__exit__(None, None, None) + +assert gpu_module.operation.verify() + +simplified_module = run_pipeline( + ctx.module, + Pipeline() + .canonicalize() + .cse() + .loop_invariant_code_motion() + .loop_invariant_subset_hoisting() + .rocdl_attach_target(chip=get_hip_arch(), O=3, abi="500"), +) + +assert simplified_module.operation.verify() +# print(simplified_module) + +lowered_module = run_pipeline( + simplified_module, + Pipeline() + .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True)) + .gpu_to_llvm() + .lower_to_llvm(), + # .Nested("llvm.func", Pipeline().sroa()), +) + +assert lowered_module.operation.verify() +# print(lowered_module) + +gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp)) +for g in gep: + g.attributes["inbounds"] = UnitAttr.get() + + +kernel_funcs = find_ops( + lowered_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp) +) +target_flags = "+16-bit-insts,+atomic-fadd-rtn-insts,+ci-insts,+dl-insts,+dot10-insts,+dot5-insts,+dot7-insts,+dot8-insts,+dot9-insts,+dpp,+gfx10-3-insts,+gfx10-insts,+gfx11-insts,+gfx8-insts,+gfx9-insts,+wavefrontsize32".split( + "," +) +flags = ", ".join([f'"{t}"' for t in target_flags]) +for k in kernel_funcs: + _, thread_dims, 
_ = launch_params[k.sym_name.value] + k.attributes["rocdl.max_flat_work_group_size"] = IntegerAttr.get( + T.index(), np.prod(thread_dims) + ) + k.attributes["target_features"] = Attribute.parse( + f"#llvm.target_features<[{flags}]>" + ) + + +if hip_bindings_not_installed(): + exit() +from hip import hip + + +lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary()) +hsaco = get_compile_object_bytes(lowered_module) +# with open("/home/mlevental/dev_projects/fp32_sgemm_amd/pythonkernels.hsaco", "wb") as f: +# f.write(hsaco) +hip_module = hip_check(hip.hipModuleLoadData(hsaco)) + +a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float32) +b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float32) +# a_h = np.ones((M, K)).astype(dtype=np.float32) +# b_h = np.ones((M, K)).astype(dtype=np.float32) + +a_num_bytes = a_h.size * a_h.itemsize +b_num_bytes = b_h.size * b_h.itemsize + +a_d = hip_check(hip.hipMalloc(a_num_bytes)) +b_d = hip_check(hip.hipMalloc(b_num_bytes)) + +stream = 0 + +times = { + kernel1_naive: 0, + kernel2_lds_shared_subview: 0, + kernel2_lds_shared_direct_dynamic: 0, + kernel2_lds_shared_direct_load_globals: 0, + kernel3_registers: 0, + kernel4_gmem_db: 0, + kernel5_lds_optim: 0, +} +# random.shuffle(kernels) +runs = 16 +for kernel in times: + for i in range(runs): + function = hip_check( + hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()) + ) + hip_check(hip.hipDeviceSynchronize()) + + hip_check( + hip.hipMemcpy( + a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice + ) + ) + hip_check( + hip.hipMemcpy( + b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice + ) + ) + + c_h = -3 * np.ones((M, N), dtype=np.float32) + c_num_bytes = c_h.size * c_h.itemsize + c_d = hip_check(hip.hipMalloc(c_num_bytes)) + hip_check( + hip.hipMemcpy( + c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice + ) + ) + + ( + ( + blocks_per_grid_x, + blocks_per_grid_y, + blocks_per_grid_z, + ), + ( + 
threads_per_block_x, + threads_per_block_y, + threads_per_block_z, + ), + shared_memory, + ) = launch_params[kernel.__name__] + + time_compute = launch_kernel( + function.as_c_void_p(), + blocks_per_grid_x, + blocks_per_grid_y, + blocks_per_grid_z, + threads_per_block_x, + threads_per_block_y, + threads_per_block_z, + stream, + shared_memory, + a_d, + b_d, + c_d, + ) + + hip_check( + hip.hipMemcpy( + c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost + ) + ) + correct = a_h @ b_h + if not np.allclose(correct, c_h): + # with np.printoptions(threshold=np.inf, linewidth=np.inf): + # print(correct) + # print(c_h) + print(f"{kernel.__name__} failed") + + times[kernel] += time_compute + + # print(f"{kernel.__name__} : {time_compute}") + +for k in times: + times[k] /= runs + +for k, v in times.items(): + print(f"{k.__name__}: {v:.03f}ms GLOPs {time_to_gflops(v, N)}") diff --git a/projects/eudsl-python-extras/examples/util.py b/projects/eudsl-python-extras/examples/util.py new file mode 100644 index 00000000..f82f8c5b --- /dev/null +++ b/projects/eudsl-python-extras/examples/util.py @@ -0,0 +1,174 @@ +import ctypes +import sys + + +def jax_not_installed(): + try: + from jaxlib import mlir + + # don't skip + return False + + except ImportError: + # skip + return True + + +def mlir_bindings_not_installed(): + try: + import mlir.extras + + # don't skip + return False + + except ImportError: + # skip + return True + + +def llvm_bindings_not_installed(): + try: + import llvm + + # don't skip + return False + + except ImportError: + # skip + return True + + +def hip_check(call_result): + from hip import hip + + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + return result + + +def hip_synchronize(): + from hip import hip + + hip.hipDeviceSynchronize() + + +def hip_bindings_not_installed(): + try: + from hip import hip 
+ + props = hip.hipDeviceProp_t() + hip_check(hip.hipGetDeviceProperties(props, 0)) + + # don't skip + return False + + except ImportError: + return True + + except Exception as e: + print(e, file=sys.stderr) + # skip + return True + + +def cuda_bindings_not_installed(): + try: + import cupy as cp + import numpy as np + + A = np.random.randint(0, 10, (10, 10)) + dA = cp.asarray(A) + + # don't skip + return False + + except ImportError: + return True + + except Exception as e: + print(e, file=sys.stderr) + # skip + return True + + +def chip_check(status): + import chip + + if status != 0: + raise RuntimeError( + f"HIP Error {status}, {ctypes.string_at(chip.hipGetErrorString(status)).decode()}" + ) + + +def launch_kernel( + function, + blocks_per_grid_x, + blocks_per_grid_y, + blocks_per_grid_z, + threads_per_block_x, + threads_per_block_y, + threads_per_block_z, + stream, + shared_memory, + *args, +): + import chip + + import hip + from hip._util.types import DeviceArray + + params = [None] * len(args) + addresses = [None] * len(args) + for i, p in enumerate(args): + if isinstance(p, DeviceArray): + addresses[i] = params[i] = p.createRef().as_c_void_p() + elif isinstance(p, int): + params[i] = ctypes.c_int32(p) + addresses[i] = ctypes.addressof(params[i]) + else: + raise NotImplementedError(f"{p=} not supported with {p=}") + + c_args = (ctypes.c_void_p * len(addresses))(*addresses) + function = ctypes.cast(function, chip.hipFunction_t) + stream = ctypes.cast(stream, chip.hipStream_t) + + tstart = hip_check(hip.hip.hipEventCreate()) + tstop = hip_check(hip.hip.hipEventCreate()) + hip_check(hip.hip.hipEventRecord(tstart, None)) + + r = chip.hipModuleLaunchKernel( + function, + blocks_per_grid_x, + blocks_per_grid_y, + blocks_per_grid_z, + threads_per_block_x, + threads_per_block_y, + threads_per_block_z, + shared_memory, + stream, + c_args, + None, + ) + + hip_check(hip.hip.hipEventRecord(tstop, None)) + hip_check(hip.hip.hipEventSynchronize(tstop)) + time_compute = 
hip_check(hip.hip.hipEventElapsedTime(tstart, tstop)) + + chip_check(r) + + return time_compute + + +def get_hip_arch(): + if hip_bindings_not_installed(): + return "gfx1100" + + from hip import hip + + props = hip.hipDeviceProp_t() + hip_check(hip.hipGetDeviceProperties(props, 0)) + return props.gcnArchName.decode() diff --git a/projects/eudsl-python-extras/examples/vectorization_e2e.ipynb b/projects/eudsl-python-extras/examples/vectorization_e2e.ipynb new file mode 100644 index 00000000..18d64866 --- /dev/null +++ b/projects/eudsl-python-extras/examples/vectorization_e2e.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Welcome to `eudsl-python-extras` enjoy your stay!\n", + "\n", + "more at https://github.com/llvm/eudsl/tree/main/projects/eudsl-python-extras" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install eudsl-python-extras mlir-python-bindings -f https://llvm.github.io/eudsl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "import mlir.extras.types as T\n", + "from mlir.dialects import builtin\n", + "from mlir.dialects.transform import any_op_t\n", + "from mlir.dialects.transform.extras import named_sequence, apply_patterns\n", + "from mlir.extras.util import find_ops\n", + "from mlir.ir import StringAttr, UnitAttr\n", + "\n", + "# you need this to register the memref value caster\n", + "# noinspection PyUnresolvedReferences\n", + "import mlir.extras.dialects.memref\n", + "from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule\n", + "from mlir.dialects.bufferization import LayoutMapOption\n", + "from mlir.dialects.transform.vector import (\n", + " VectorContractLowering,\n", + " VectorMultiReductionLowering,\n", + " 
VectorTransferSplit,\n", + " VectorTransposeLowering,\n", + ")\n", + "from mlir.extras.dialects import linalg\n", + "from mlir.extras.dialects.func import func\n", + "from mlir.extras.dialects.transform import (\n", + " match,\n", + " tile_to_scf_for,\n", + " get_parent_op,\n", + " transform_any_op_t,\n", + ")\n", + "from mlir.extras.dialects import transform\n", + "from mlir.extras.runtime.passes import Pipeline, run_pipeline\n", + "from mlir.extras.runtime.refbackend import LLVMJITBackend\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s-JTcrjo7tNK" + }, + "source": [ + "# Context" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AGpWj9BzZLC_" + }, + "outputs": [], + "source": [ + "ctx = RAIIMLIRContext()\n", + "module = ExplicitlyManagedModule()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qGcDtgkv71YB" + }, + "source": [ + "# Kernel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7oQk4xJd72FI" + }, + "outputs": [], + "source": [ + "M, K, N = 2, 4, 6\n", + "\n", + "\n", + "@func\n", + "def matmul_tensors(\n", + " A: T.tensor(M, K, T.f32()),\n", + " B: T.tensor(K, N, T.f32()),\n", + " C: T.tensor(M, N, T.f32()),\n", + "):\n", + " return linalg.matmul(A, B, C)\n", + "\n", + "@builtin.module(attrs={\"transform.target_tag\": StringAttr.get(\"payload\")})\n", + "def payload():\n", + " matmul_tensors.emit(force=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a0vJZrpR74KB" + }, + "source": [ + "# Transform schedule (based on [transform-e2e.mlir](https://github.com/llvm/llvm-project/blob/375bd2201ce0d2c76cb47a02c87b8ca5ba8a3509/mlir/test/Dialect/LLVM/transform-e2e.mlir))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EaBgGTIz72ci" + }, + "outputs": [], + "source": [ + "@builtin.module(attrs={\"transform.with_named_sequence\": UnitAttr.get()})\n", + "def mod_transform():\n", + " 
@named_sequence(\"main\", [any_op_t()], [])\n", + " def main(module_op: any_op_t()):\n", + " matmul = match(module_op, ops=[\"linalg.matmul\"])\n", + " tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])\n", + " transform.structured.vectorize_children_and_apply_patterns(\n", + " get_parent_op(transform_any_op_t(), tiled_matmul, isolated_from_above=True)\n", + " )\n", + " new_mod = transform.bufferization.one_shot_bufferize(\n", + " module_op,\n", + " function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,\n", + " bufferize_function_boundaries=True,\n", + " )\n", + "\n", + " func_op = match(new_mod, ops=[\"func.func\"])\n", + "\n", + " @apply_patterns(func_op)\n", + " def pats():\n", + " transform.apply_patterns.vector.lower_contraction(\n", + " lowering_strategy=VectorContractLowering.OuterProduct\n", + " )\n", + " transform.apply_patterns.vector.transfer_permutation_patterns()\n", + " transform.apply_patterns.vector.lower_multi_reduction(\n", + " lowering_strategy=VectorMultiReductionLowering.InnerParallel\n", + " )\n", + " transform.apply_patterns.vector.split_transfer_full_partial(\n", + " split_transfer_strategy=VectorTransferSplit.LinalgCopy\n", + " )\n", + " transform.apply_patterns.vector.transfer_to_scf(\n", + " max_transfer_rank=1, full_unroll=True\n", + " )\n", + " transform.apply_patterns.vector.lower_transfer(max_transfer_rank=1)\n", + " transform.apply_patterns.vector.lower_shape_cast()\n", + " transform.apply_patterns.vector.lower_transpose(\n", + " lowering_strategy=VectorTransposeLowering.Shuffle1D\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ADbabroS8ND2" + }, + "source": [ + "# \"Finish\" the module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CUOsYXaW8QKC", + "outputId": "f8592229-1d9b-4c52-9133-30fd52c2716d" + }, + "outputs": [], + "source": [ + "module = module.finish()\n", + 
"print(module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0xN5kNvZ8Tyf" + }, + "source": [ + "# Vectorize (execute the transform schedule)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lLwQLPD98Q4d", + "outputId": "ecfa6c9a-15eb-40c7-df29-f43fcac02fbf" + }, + "outputs": [], + "source": [ + "vectorized_module = run_pipeline(\n", + " module,\n", + " pipeline=Pipeline().transform_interpreter(\n", + " entry_point=\"main\", debug_payload_root_tag=\"payload\"\n", + " ),\n", + ")\n", + "print(vectorized_module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D_NURglF8ZZW" + }, + "source": [ + "# Lower to CPU (through LLVM, based on [TestLowerToLLVM.cpp](https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9IoWjgc48bcn", + "outputId": "39550464-fd37-4e6d-a257-e803b746d8de" + }, + "outputs": [], + "source": [ + "lower_to_llvm = (\n", + " Pipeline()\n", + " .Func(\n", + " Pipeline()\n", + " # Blanket-convert any remaining high-level vector ops to loops if any remain.\n", + " .convert_vector_to_scf()\n", + " # Blanket-convert any remaining linalg ops to loops if any remain.\n", + " .convert_linalg_to_loops()\n", + " )\n", + " # Blanket-convert any remaining affine ops if any remain.\n", + " .lower_affine()\n", + " # Convert SCF to CF (always needed).\n", + " .convert_scf_to_cf()\n", + " # Sprinkle some cleanups.\n", + " .canonicalize()\n", + " .cse()\n", + " # Convert vector to LLVM (always needed).\n", + " .convert_vector_to_llvm()\n", + " # Convert Math to LLVM (always needed).\n", + " .Func(Pipeline().convert_math_to_llvm())\n", + " # Expand complicated MemRef operations before lowering them.\n", + " 
.expand_strided_metadata()\n", + " # The expansion may create affine expressions. Get rid of them.\n", + " .lower_affine()\n", + " # Convert MemRef to LLVM (always needed).\n", + " .finalize_memref_to_llvm()\n", + " # Convert Func to LLVM (always needed).\n", + " .convert_func_to_llvm()\n", + " .convert_arith_to_llvm()\n", + " .convert_cf_to_llvm()\n", + " # Convert Index to LLVM (always needed).\n", + " .convert_index_to_llvm()\n", + " # Convert remaining unrealized_casts (always needed).\n", + " .reconcile_unrealized_casts()\n", + ")\n", + "\n", + "backend = LLVMJITBackend()\n", + "compiled_module = backend.compile(\n", + " find_ops(\n", + " vectorized_module.operation,\n", + " lambda x: \"transform.target_tag\" in x.attributes\n", + " and x.attributes[\"transform.target_tag\"].value == \"payload\",\n", + " single=True,\n", + " ),\n", + " kernel_name=matmul_tensors.__name__,\n", + " pipeline=lower_to_llvm,\n", + ")\n", + "print(compiled_module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sOapyydH8n4h" + }, + "source": [ + "# Load, run, and compare against numpy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pOEC4Qgw8p9X" + }, + "outputs": [], + "source": [ + "A = np.random.randint(0, 10, (M, K)).astype(np.float32)\n", + "B = np.random.randint(0, 10, (K, N)).astype(np.float32)\n", + "C = np.zeros((M, N), dtype=np.float32)\n", + "\n", + "backend.load(compiled_module).matmul_tensors_capi_wrapper(A, B, C)\n", + "assert np.allclose(A @ B, C)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 
4 +}