Skip to content

Commit 0625fdc

Browse files
Merge pull request #33720 from olupton:thor
PiperOrigin-RevId: 840276978
2 parents 174d514 + 01513b2 commit 0625fdc

File tree

3 files changed: +51 additions, −55 deletions

jax/_src/test_util.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,11 @@ def test_method_wrapper(self, *args, **kwargs):
545545
)
546546

547547
def get_cuda_nonportable_max_cluster_size():
548-
if device_kind_match("GB10$"):
549-
# 12 is the nonportable maximum cluster size on DGX Spark,
550-
# determined by querying cuOccupancyMaxPotentialClusterSize.
548+
# Per-device nonportable maximum cluster sizes for Jetson Thor and DGX
549+
# Spark (GB10) determined by querying cuOccupancyMaxPotentialClusterSize
550+
if device_kind_match("Thor$"):
551+
return 8
552+
elif device_kind_match("GB10$"):
551553
return 12
552554
# 16 is the nonportable maximum cluster size on:
553555
# - Hopper: https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#:~:text=cluster%20size%20of-,16,-by%20opting%20in

jaxlib/mosaic/gpu/runtime.cc

Lines changed: 43 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ limitations under the License.
2020
#include "third_party/gpus/cuda/include/cuda.h"
2121
#include "jaxlib/mosaic/gpu/nvshmem.h"
2222

23+
namespace {
24+
template <typename... Args>
25+
void abort_on_error(CUresult result, const char* fmt, Args&&... args) {
26+
if (result != CUDA_SUCCESS) {
27+
const char *ptr = nullptr;
28+
cuGetErrorString(result, &ptr);
29+
fprintf(stderr, fmt, std::forward<Args>(args)..., ptr);
30+
abort();
31+
}
32+
}
33+
}
34+
2335
extern "C" {
2436

2537
void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
@@ -159,27 +171,18 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
159171
fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes);
160172
abort();
161173
}
162-
CUresult result = cuTensorMapEncodeTiled(
174+
abort_on_error(
175+
cuTensorMapEncodeTiled(
163176
tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides,
164177
tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
165-
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
166-
if (result != CUDA_SUCCESS) {
167-
const char *ptr = nullptr;
168-
cuGetErrorString(result, &ptr);
169-
fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr);
170-
abort();
171-
}
178+
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
179+
"cuTensorMapEncodeTiled failed: %s\n");
172180
}
173181

174182
void* mosaic_gpu_module_load(void *data) {
175183
CUmodule module = nullptr;
176-
if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) {
177-
const char *ptr = nullptr;
178-
cuGetErrorString(result, &ptr);
179-
fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr);
180-
abort();
181-
}
182-
184+
abort_on_error(cuModuleLoadData(&module, data),
185+
"cuModuleLoadData failed: %s\n");
183186
{ // Set the NVSHMEM state if it's used by the module.
184187
CUdeviceptr ptr = 0;
185188
size_t size = 0;
@@ -200,41 +203,23 @@ void* mosaic_gpu_module_load(void *data) {
200203
void *mosaic_gpu_get_function(CUmodule module, const char *name,
201204
int32_t smem_bytes, int32_t cluster_size) {
202205
CUfunction function = nullptr;
203-
CUresult result = cuModuleGetFunction(&function, module, name);
204-
if (result != CUDA_SUCCESS) {
205-
const char *ptr = nullptr;
206-
cuGetErrorString(result, &ptr);
207-
fprintf(stderr,
208-
"Failed to retrieve function pointer to kernel \"%s\", "
209-
"cuModuleGetFunction failed: %s\n",
210-
name, ptr);
211-
abort();
212-
}
206+
abort_on_error(
207+
cuModuleGetFunction(&function, module, name),
208+
"Failed to retrieve function pointer to kernel \"%s\", "
209+
"cuModuleGetFunction failed: %s\n", name);
213210
if (smem_bytes) {
214-
result = cuFuncSetAttribute(
215-
function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes);
216-
if (result != CUDA_SUCCESS) {
217-
const char *ptr = nullptr;
218-
cuGetErrorString(result, &ptr);
219-
fprintf(stderr,
220-
"Failed to set maximum dynamic shared memory size for kernel "
221-
"\"%s\" to %d bytes, cuFuncSetAttribute failed: %s\n",
222-
name, smem_bytes, ptr);
223-
abort();
224-
}
211+
abort_on_error(
212+
cuFuncSetAttribute(
213+
function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes),
214+
"Failed to set maximum dynamic shared memory size for kernel \"%s\" "
215+
"to %d bytes, cuFuncSetAttribute failed: %s\n", name, smem_bytes);
225216
}
226217
if (cluster_size > 8) {
227-
result = cuFuncSetAttribute(
228-
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1);
229-
if (result != CUDA_SUCCESS) {
230-
const char *ptr = nullptr;
231-
cuGetErrorString(result, &ptr);
232-
fprintf(stderr,
233-
"Failed to set allowed cluster size for kernel \"%s\" to %d, "
234-
"cuFuncSetAttribute failed: %s\n",
235-
name, cluster_size, ptr);
236-
abort();
237-
}
218+
abort_on_error(
219+
cuFuncSetAttribute(
220+
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1),
221+
"Failed to set allowed cluster size for kernel \"%s\" to %d, "
222+
"cuFuncSetAttribute failed: %s\n", name, cluster_size);
238223
}
239224
return function;
240225
}
@@ -270,11 +255,18 @@ void mosaic_gpu_launch_kernel(CUfunction function, uint32_t grid_x,
270255
config.numAttrs = 1;
271256
}
272257
CUresult result = cuLaunchKernelEx(&config, function, params, nullptr);
273-
if (result != CUDA_SUCCESS) {
274-
const char *ptr = nullptr;
275-
cuGetErrorString(result, &ptr);
276-
fprintf(stderr, "cuLaunchKernel failed: %s\n", ptr);
258+
if (result == CUDA_ERROR_INVALID_CLUSTER_SIZE) {
259+
int max_cluster_size;
260+
abort_on_error(cuOccupancyMaxPotentialClusterSize(&max_cluster_size,
261+
function, &config),
262+
"cuOccupancyMaxPotentialClusterSize failed: %s\n");
263+
fprintf(stderr,
264+
"cuLaunchKernel failed with invalid cluster size (%d, %d, %d)"
265+
": maximum is %d\n", cluster_x, cluster_y, cluster_z,
266+
max_cluster_size);
277267
abort();
268+
} else {
269+
abort_on_error(result, "cuLaunchKernelEx: %s\n");
278270
}
279271
}
280272
}

tests/pallas/mosaic_gpu_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,9 @@ def kernel(x_ref, o_ref):
12041204
self.assertEqual(output(), "It works!\n")
12051205

12061206
def test_print_wgmma_tiled_layout(self):
1207-
shape = (128, 64)
1207+
# The default printf buffer on some smaller GPUs (e.g. Thor) only has space for
1208+
# 4096 threads to printf (short) messages. Keep this shape below that.
1209+
shape = (128, 32)
12081210
size = math.prod(shape)
12091211

12101212
@functools.partial(

Comments (0)