Skip to content

Commit 28902ec

Browse files
committed
updated, simplified mutex for thread safety
1 parent e496c5a commit 28902ec

File tree

4 files changed

+19
-0
lines changed

4 files changed

+19
-0
lines changed

mlx/backend/metal/eval.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
#include "mlx/backend/gpu/available.h"
55
#include "mlx/backend/gpu/eval.h"
66
#include "mlx/backend/metal/device.h"
7+
#include "mlx/backend/metal/thread_safey.h"
78
#include "mlx/backend/metal/utils.h"
89
#include "mlx/primitives.h"
910
#include "mlx/scheduler.h"
1011

1112
namespace mlx::core::gpu {
1213

14+
std::mutex metal_operation_mutex;
15+
1316
bool is_available() {
1417
return true;
1518
}
@@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
3033
}
3134

3235
void eval(array& arr) {
36+
std::lock_guard<std::mutex> lock(metal_operation_mutex);
3337
auto pool = metal::new_scoped_memory_pool();
3438
auto s = arr.primitive().stream();
3539
auto& d = metal::device(s.device);
@@ -78,6 +82,7 @@ void eval(array& arr) {
7882
}
7983

8084
void finalize(Stream s) {
85+
std::lock_guard<std::mutex> lock(metal_operation_mutex);
8186
auto pool = metal::new_scoped_memory_pool();
8287
auto& d = metal::device(s.device);
8388
auto cb = d.get_command_buffer(s.index);
@@ -88,6 +93,7 @@ void finalize(Stream s) {
8893
}
8994

9095
void synchronize(Stream s) {
96+
std::lock_guard<std::mutex> lock(metal_operation_mutex);
9197
auto pool = metal::new_scoped_memory_pool();
9298
auto& d = metal::device(s.device);
9399
auto cb = d.get_command_buffer(s.index);

mlx/backend/metal/event.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "mlx/event.h"
44
#include "mlx/backend/metal/device.h"
5+
#include "mlx/backend/metal/thread_safey.h"
56
#include "mlx/scheduler.h"
67

78
namespace mlx::core {
@@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
2728
if (stream.device == Device::cpu) {
2829
scheduler::enqueue(stream, [*this]() mutable { wait(); });
2930
} else {
31+
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
3032
auto& d = metal::device(stream.device);
3133
d.end_encoding(stream.index);
3234
auto command_buffer = d.get_command_buffer(stream.index);
@@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
4143
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
4244
});
4345
} else {
46+
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
4447
auto& d = metal::device(stream.device);
4548
d.end_encoding(stream.index);
4649
auto command_buffer = d.get_command_buffer(stream.index);

mlx/backend/metal/fence.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright © 2024 Apple Inc.
22
#include "mlx/fence.h"
33
#include "mlx/backend/metal/device.h"
4+
#include "mlx/backend/metal/thread_safey.h"
45
#include "mlx/scheduler.h"
56
#include "mlx/utils.h"
67

@@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
6869
return;
6970
}
7071

72+
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
7173
auto& d = metal::device(stream.device);
7274
auto idx = stream.index;
7375

@@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
116118
return;
117119
}
118120

121+
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
119122
auto& d = metal::device(stream.device);
120123
auto idx = stream.index;
121124
if (!f.use_fast) {

mlx/backend/metal/thread_safey.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <mutex>
4+
5+
namespace mlx::core::gpu {
6+
extern std::mutex metal_operation_mutex;
7+
}

0 commit comments

Comments
 (0)