Skip to content

Commit 9fa51db

Browse files
authored
[Serve] Enable GPU sampler for Metal (#3349)
Add "metal" to the list of supported targets in AttachGPUSamplingFunc's transform_module so that GPU sampling functions are attached to models compiled for Metal. Update SupportGPUSampler so that these GPU sampling functions are used on Metal devices at runtime.
1 parent 157f87b commit 9fa51db

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

cpp/serve/sampler/sampler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ class Sampler : public ObjectRef {
144144
/*! \brief Check if the given device supports GPU sampling. */
145145
static bool SupportGPUSampler(Device device) {
146146
return device.device_type == DLDeviceType::kDLCUDA ||
147-
device.device_type == DLDeviceType::kDLVulkan;
147+
device.device_type == DLDeviceType::kDLVulkan ||
148+
device.device_type == DLDeviceType::kDLMetal;
148149
}
149150

150151
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sampler, ObjectRef, SamplerObj);

python/mlc_llm/compiler_pass/attach_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]):
2828

2929
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
3030
"""Entrypoint"""
31-
if str(self.target.kind) not in ["cuda", "vulkan"]:
32-
# Only enable GPU sampling for CUDA.
31+
if str(self.target.kind) not in ["cuda", "vulkan", "metal"]:
32+
# Only enable GPU sampling for CUDA, Vulkan, and Metal.
3333
return mod
3434

3535
bb = relax.BlockBuilder(mod)

0 commit comments

Comments
 (0)