Skip to content

Commit 82a117e

Browse files
committed
Merge branch 'main' into metal-thread-safe
2 parents 24afa0b + fdadc4f commit 82a117e

File tree

20 files changed

+523
-183
lines changed

20 files changed

+523
-183
lines changed

mlx/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ target_sources(
55
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
66
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
77
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
8+
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
89
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
910
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
1011
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

mlx/array.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,11 @@ class array {
339339
return allocator::allocator().size(buffer());
340340
}
341341

342-
// Return a copy of the shared pointer
343-
// to the array::Data struct
344-
std::shared_ptr<Data> data_shared_ptr() const {
342+
// Return the shared pointer to the array::Data struct
343+
const std::shared_ptr<Data>& data_shared_ptr() const {
345344
return array_desc_->data;
346345
}
346+
347347
// Return a raw pointer to the arrays data
348348
template <typename T>
349349
T* data() {

mlx/backend/metal/kernels/complex.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
104104
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
105105
return {a.real + b.real, a.imag + b.imag};
106106
}
107+
constexpr complex64_t operator+(float a, complex64_t b) {
108+
return {a + b.real, b.imag};
109+
}
110+
constexpr complex64_t operator+(complex64_t a, float b) {
111+
return {a.real + b, a.imag};
112+
}
107113

108114
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
109115
return {a.real - b.real, a.imag - b.imag};
110116
}
117+
constexpr complex64_t operator-(float a, complex64_t b) {
118+
return {a - b.real, -b.imag};
119+
}
120+
constexpr complex64_t operator-(complex64_t a, float b) {
121+
return {a.real - b, a.imag};
122+
}
111123

112124
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
113125
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
@@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
120132
return {x / denom, y / denom};
121133
}
122134

135+
constexpr complex64_t operator/(float a, complex64_t b) {
136+
auto denom = b.real * b.real + b.imag * b.imag;
137+
auto x = a * b.real;
138+
auto y = -a * b.imag;
139+
return {x / denom, y / denom};
140+
}
141+
123142
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
124143
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
125144
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

mlx/backend/metal/kernels/unary.metal

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ instantiate_unary_float(Round)
6969
instantiate_unary_int(BitwiseInvert)
7070

7171
instantiate_unary_all_same(Abs, complex64, complex64_t)
72+
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
73+
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
74+
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
7275
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
7376
instantiate_unary_all_same(Cos, complex64, complex64_t)
7477
instantiate_unary_all_same(Cosh, complex64, complex64_t)
@@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
8083
instantiate_unary_all_same(Sign, complex64, complex64_t)
8184
instantiate_unary_all_same(Sin, complex64, complex64_t)
8285
instantiate_unary_all_same(Sinh, complex64, complex64_t)
86+
instantiate_unary_all_same(Square, complex64, complex64_t)
87+
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
88+
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
8389
instantiate_unary_all_same(Tan, complex64, complex64_t)
8490
instantiate_unary_all_same(Tanh, complex64, complex64_t)
8591
instantiate_unary_all_same(Round, complex64, complex64_t)

mlx/backend/metal/kernels/unary_ops.h

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,21 @@ struct Abs {
1717
T operator()(T x) {
1818
return metal::abs(x);
1919
};
20-
template <>
2120
uint8_t operator()(uint8_t x) {
2221
return x;
2322
};
24-
template <>
2523
uint16_t operator()(uint16_t x) {
2624
return x;
2725
};
28-
template <>
2926
uint32_t operator()(uint32_t x) {
3027
return x;
3128
};
32-
template <>
3329
uint64_t operator()(uint64_t x) {
3430
return x;
3531
};
36-
template <>
3732
bool operator()(bool x) {
3833
return x;
3934
};
40-
template <>
4135
complex64_t operator()(complex64_t x) {
4236
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
4337
};
@@ -48,6 +42,8 @@ struct ArcCos {
4842
T operator()(T x) {
4943
return metal::precise::acos(x);
5044
};
45+
46+
complex64_t operator()(complex64_t x);
5147
};
5248

5349
struct ArcCosh {
@@ -62,6 +58,8 @@ struct ArcSin {
6258
T operator()(T x) {
6359
return metal::precise::asin(x);
6460
};
61+
62+
complex64_t operator()(complex64_t x);
6563
};
6664

6765
struct ArcSinh {
@@ -76,6 +74,8 @@ struct ArcTan {
7674
T operator()(T x) {
7775
return metal::precise::atan(x);
7876
};
77+
78+
complex64_t operator()(complex64_t x);
7979
};
8080

8181
struct ArcTanh {
@@ -97,39 +97,30 @@ struct Ceil {
9797
T operator()(T x) {
9898
return metal::ceil(x);
9999
};
100-
template <>
101100
int8_t operator()(int8_t x) {
102101
return x;
103102
};
104-
template <>
105103
int16_t operator()(int16_t x) {
106104
return x;
107105
};
108-
template <>
109106
int32_t operator()(int32_t x) {
110107
return x;
111108
};
112-
template <>
113109
int64_t operator()(int64_t x) {
114110
return x;
115111
};
116-
template <>
117112
uint8_t operator()(uint8_t x) {
118113
return x;
119114
};
120-
template <>
121115
uint16_t operator()(uint16_t x) {
122116
return x;
123117
};
124-
template <>
125118
uint32_t operator()(uint32_t x) {
126119
return x;
127120
};
128-
template <>
129121
uint64_t operator()(uint64_t x) {
130122
return x;
131123
};
132-
template <>
133124
bool operator()(bool x) {
134125
return x;
135126
};
@@ -141,7 +132,6 @@ struct Cos {
141132
return metal::precise::cos(x);
142133
};
143134

144-
template <>
145135
complex64_t operator()(complex64_t x) {
146136
return {
147137
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
@@ -155,7 +145,6 @@ struct Cosh {
155145
return metal::precise::cosh(x);
156146
};
157147

158-
template <>
159148
complex64_t operator()(complex64_t x) {
160149
return {
161150
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
@@ -188,7 +177,6 @@ struct Exp {
188177
T operator()(T x) {
189178
return metal::precise::exp(x);
190179
};
191-
template <>
192180
complex64_t operator()(complex64_t x) {
193181
auto m = metal::precise::exp(x.real);
194182
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
@@ -207,39 +195,30 @@ struct Floor {
207195
T operator()(T x) {
208196
return metal::floor(x);
209197
};
210-
template <>
211198
int8_t operator()(int8_t x) {
212199
return x;
213200
};
214-
template <>
215201
int16_t operator()(int16_t x) {
216202
return x;
217203
};
218-
template <>
219204
int32_t operator()(int32_t x) {
220205
return x;
221206
};
222-
template <>
223207
int64_t operator()(int64_t x) {
224208
return x;
225209
};
226-
template <>
227210
uint8_t operator()(uint8_t x) {
228211
return x;
229212
};
230-
template <>
231213
uint16_t operator()(uint16_t x) {
232214
return x;
233215
};
234-
template <>
235216
uint32_t operator()(uint32_t x) {
236217
return x;
237218
};
238-
template <>
239219
uint64_t operator()(uint64_t x) {
240220
return x;
241221
};
242-
template <>
243222
bool operator()(bool x) {
244223
return x;
245224
};
@@ -258,7 +237,6 @@ struct Log {
258237
return metal::precise::log(x);
259238
};
260239

261-
template <>
262240
complex64_t operator()(complex64_t x) {
263241
auto r = metal::precise::log(Abs{}(x).real);
264242
auto i = metal::precise::atan2(x.imag, x.real);
@@ -272,7 +250,6 @@ struct Log2 {
272250
return metal::precise::log2(x);
273251
};
274252

275-
template <>
276253
complex64_t operator()(complex64_t x) {
277254
auto y = Log{}(x);
278255
return {y.real / M_LN2_F, y.imag / M_LN2_F};
@@ -285,7 +262,6 @@ struct Log10 {
285262
return metal::precise::log10(x);
286263
};
287264

288-
template <>
289265
complex64_t operator()(complex64_t x) {
290266
auto y = Log{}(x);
291267
return {y.real / M_LN10_F, y.imag / M_LN10_F};
@@ -325,7 +301,6 @@ struct Round {
325301
T operator()(T x) {
326302
return metal::rint(x);
327303
};
328-
template <>
329304
complex64_t operator()(complex64_t x) {
330305
return {metal::rint(x.real), metal::rint(x.imag)};
331306
};
@@ -344,11 +319,9 @@ struct Sign {
344319
T operator()(T x) {
345320
return (x > T(0)) - (x < T(0));
346321
};
347-
template <>
348322
uint32_t operator()(uint32_t x) {
349323
return x != 0;
350324
};
351-
template <>
352325
complex64_t operator()(complex64_t x) {
353326
if (x == complex64_t(0)) {
354327
return x;
@@ -364,7 +337,6 @@ struct Sin {
364337
return metal::precise::sin(x);
365338
};
366339

367-
template <>
368340
complex64_t operator()(complex64_t x) {
369341
return {
370342
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
@@ -378,7 +350,6 @@ struct Sinh {
378350
return metal::precise::sinh(x);
379351
};
380352

381-
template <>
382353
complex64_t operator()(complex64_t x) {
383354
return {
384355
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
@@ -398,13 +369,28 @@ struct Sqrt {
398369
T operator()(T x) {
399370
return metal::precise::sqrt(x);
400371
};
372+
373+
complex64_t operator()(complex64_t x) {
374+
if (x.real == 0.0 && x.imag == 0.0) {
375+
return {0.0, 0.0};
376+
}
377+
auto r = Abs{}(x).real;
378+
auto a = metal::precise::sqrt((r + x.real) / 2.0);
379+
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
380+
auto b = metal::copysign(b_abs, x.imag);
381+
return {a, b};
382+
}
401383
};
402384

403385
struct Rsqrt {
404386
template <typename T>
405387
T operator()(T x) {
406388
return metal::precise::rsqrt(x);
407389
};
390+
391+
complex64_t operator()(complex64_t x) {
392+
return 1.0 / Sqrt{}(x);
393+
}
408394
};
409395

410396
struct Tan {
@@ -413,7 +399,6 @@ struct Tan {
413399
return metal::precise::tan(x);
414400
};
415401

416-
template <>
417402
complex64_t operator()(complex64_t x) {
418403
float tan_a = metal::precise::tan(x.real);
419404
float tanh_b = metal::precise::tanh(x.imag);
@@ -429,7 +414,6 @@ struct Tanh {
429414
return metal::precise::tanh(x);
430415
};
431416

432-
template <>
433417
complex64_t operator()(complex64_t x) {
434418
float tanh_a = metal::precise::tanh(x.real);
435419
float tan_b = metal::precise::tan(x.imag);
@@ -438,3 +422,21 @@ struct Tanh {
438422
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
439423
};
440424
};
425+
426+
complex64_t ArcCos::operator()(complex64_t x) {
427+
auto i = complex64_t{0.0, 1.0};
428+
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
429+
return {y.imag, -y.real};
430+
};
431+
432+
complex64_t ArcSin::operator()(complex64_t x) {
433+
auto i = complex64_t{0.0, 1.0};
434+
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
435+
return {y.imag, -y.real};
436+
};
437+
438+
complex64_t ArcTan::operator()(complex64_t x) {
439+
auto i = complex64_t{0.0, 1.0};
440+
auto ix = i * x;
441+
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
442+
};

mlx/dtype_utils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/dtype_utils.h"
4+
5+
namespace mlx::core {
6+
7+
const char* dtype_to_string(Dtype arg) {
8+
if (arg == bool_) {
9+
return "bool";
10+
}
11+
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
12+
if (DTYPE == arg) { \
13+
return #DTYPE; \
14+
}
15+
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
16+
#undef SPECIALIZE_DtypeToString
17+
return "(unknown)";
18+
}
19+
20+
} // namespace mlx::core

0 commit comments

Comments
 (0)