Skip to content

Commit 2e7c02d

Browse files
barronalexAlex Barronawni
authored
Metal FFT for powers of 2 up to 2048 (ml-explore#915)
* add Metal FFT for powers of 2 * skip GPU test on linux * fix contiguity bug * address comments * Update mlx/backend/metal/fft.cpp * Update mlx/backend/metal/fft.cpp * fix bug in synch --------- Co-authored-by: Alex Barron <[email protected]> Co-authored-by: Awni Hannun <[email protected]> Co-authored-by: Awni Hannun <[email protected]>
1 parent ae18326 commit 2e7c02d

File tree

6 files changed

+431
-31
lines changed

6 files changed

+431
-31
lines changed

benchmarks/python/fft_bench.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright © 2024 Apple Inc.
2+
3+
import matplotlib
4+
import mlx.core as mx
5+
import numpy as np
6+
from time_utils import measure_runtime
7+
8+
matplotlib.use("Agg")
9+
import matplotlib.pyplot as plt
10+
11+
12+
def bandwidth_gb(runtime_ms, system_size):
13+
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
14+
bytes_per_gb = 1e9
15+
ms_per_s = 1e3
16+
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
17+
18+
19+
def run_bench(system_size):
20+
def fft(x):
21+
out = mx.fft.fft(x)
22+
mx.eval(out)
23+
return out
24+
25+
bandwidths = []
26+
for k in range(4, 12):
27+
n = 2**k
28+
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
29+
x = x.astype(mx.complex64)
30+
mx.eval(x)
31+
runtime_ms = measure_runtime(fft, x=x)
32+
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
33+
34+
return bandwidths
35+
36+
37+
def time_fft():
38+
39+
with mx.stream(mx.cpu):
40+
cpu_bandwidths = run_bench(system_size=int(2**22))
41+
42+
with mx.stream(mx.gpu):
43+
gpu_bandwidths = run_bench(system_size=int(2**29))
44+
45+
# plot bandwidths
46+
x = [2**k for k in range(4, 12)]
47+
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
48+
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
49+
plt.title("MLX FFT Benchmark")
50+
plt.xlabel("N")
51+
plt.ylabel("Bandwidth (GB/s)")
52+
plt.legend()
53+
plt.savefig("fft_plot.png")
54+
55+
56+
if __name__ == "__main__":
57+
time_fft()

mlx/backend/metal/fft.cpp

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,106 @@
11
// Copyright © 2023 Apple Inc.
2-
2+
#include "mlx/backend/metal/copy.h"
3+
#include "mlx/backend/metal/utils.h"
4+
#include "mlx/mlx.h"
35
#include "mlx/primitives.h"
46

57
namespace mlx::core {
68

79
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
10+
auto& s = out.primitive().stream();
11+
auto& d = metal::device(s.device);
12+
813
auto& in = inputs[0];
9-
throw std::runtime_error("[FFT] NYI for Metal backend.");
14+
15+
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
16+
in.dtype() != complex64 || out.dtype() != complex64) {
17+
// Could also fallback to CPU implementation here.
18+
throw std::runtime_error(
19+
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
20+
}
21+
22+
size_t n = in.shape(axes_[0]);
23+
24+
if (!is_power_of_2(n) || n > 2048 || n < 4) {
25+
throw std::runtime_error(
26+
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
27+
}
28+
29+
// Make sure that the array is contiguous and has stride 1 in the FFT dim
30+
std::vector<array> copies;
31+
auto check_input = [this, &copies, &s](const array& x) {
32+
// TODO: Pass the strides to the kernel so
33+
// we can avoid the copy when x is not contiguous.
34+
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
35+
x.flags().col_contiguous;
36+
if (no_copy) {
37+
return x;
38+
} else {
39+
array x_copy(x.shape(), x.dtype(), nullptr, {});
40+
std::vector<size_t> strides;
41+
size_t cur_stride = x.shape(axes_[0]);
42+
for (int axis = 0; axis < x.ndim(); axis++) {
43+
if (axis == axes_[0]) {
44+
strides.push_back(1);
45+
} else {
46+
strides.push_back(cur_stride);
47+
cur_stride *= x.shape(axis);
48+
}
49+
}
50+
51+
auto flags = x.flags();
52+
size_t f_stride = 1;
53+
size_t b_stride = 1;
54+
flags.col_contiguous = true;
55+
flags.row_contiguous = true;
56+
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
57+
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
58+
f_stride *= x.shape(i);
59+
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
60+
b_stride *= x.shape(ri);
61+
}
62+
// This is probably over-conservative
63+
flags.contiguous = false;
64+
65+
x_copy.set_data(
66+
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
67+
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
68+
copies.push_back(x_copy);
69+
return x_copy;
70+
}
71+
};
72+
const array& in_contiguous = check_input(inputs[0]);
73+
74+
// TODO: allow donation here
75+
out.set_data(
76+
allocator::malloc_or_wait(out.nbytes()),
77+
in_contiguous.data_size(),
78+
in_contiguous.strides(),
79+
in_contiguous.flags());
80+
81+
// We use n / 4 threads by default since radix-4
82+
// is the largest single threaded radix butterfly
83+
// we currently implement.
84+
size_t m = n / 4;
85+
size_t batch = in.size() / in.shape(axes_[0]);
86+
87+
auto& compute_encoder = d.get_command_encoder(s.index);
88+
{
89+
std::ostringstream kname;
90+
kname << "fft_" << n;
91+
auto kernel = d.get_kernel(kname.str());
92+
93+
bool donated = in.data_shared_ptr() == nullptr;
94+
compute_encoder->setComputePipelineState(kernel);
95+
compute_encoder.set_input_array(in_contiguous, 0);
96+
compute_encoder.set_output_array(out, 1);
97+
98+
auto group_dims = MTL::Size(1, m, 1);
99+
auto grid_dims = MTL::Size(batch, m, 1);
100+
compute_encoder->dispatchThreads(grid_dims, group_dims);
101+
}
102+
d.get_command_buffer(s.index)->addCompletedHandler(
103+
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
10104
}
11105

12106
} // namespace mlx::core

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ set(
2121
"binary_two"
2222
"conv"
2323
"copy"
24+
"fft"
2425
"gemv"
2526
"quantized"
2627
"random"
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
// Metal FFT using Stockham's algorithm
4+
//
5+
// References:
6+
// - VkFFT (https://github.com/DTolm/VkFFT)
7+
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
8+
9+
#include <metal_math>
10+
#include <metal_common>
11+
12+
13+
#include "mlx/backend/metal/kernels/defines.h"
14+
#include "mlx/backend/metal/kernels/utils.h"
15+
16+
using namespace metal;
17+
18+
float2 complex_mul(float2 a, float2 b) {
19+
float2 c;
20+
c.x = a.x * b.x - a.y * b.y;
21+
c.y = a.x * b.y + a.y * b.x;
22+
return c;
23+
}
24+
25+
float2 get_twiddle(int k, int p) {
26+
float theta = -1.0f * k * M_PI_F / (2*p);
27+
28+
float2 twiddle;
29+
twiddle.x = metal::fast::cos(theta);
30+
twiddle.y = metal::fast::sin(theta);
31+
return twiddle;
32+
}
33+
34+
// single threaded radix2 implemetation
35+
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
36+
float2 x_0 = read_buf[i];
37+
float2 x_1 = read_buf[i + m];
38+
39+
// The index within this sub-DFT
40+
int k = i & (p - 1);
41+
42+
float2 twiddle = get_twiddle(k, p);
43+
44+
float2 z = complex_mul(x_1, twiddle);
45+
46+
float2 y_0 = x_0 + z;
47+
float2 y_1 = x_0 - z;
48+
49+
int j = (i << 1) - k;
50+
51+
write_buf[j] = y_0;
52+
write_buf[j + p] = y_1;
53+
}
54+
55+
// single threaded radix4 implemetation
56+
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
57+
float2 x_0 = read_buf[i];
58+
float2 x_1 = read_buf[i + m];
59+
float2 x_2 = read_buf[i + 2*m];
60+
float2 x_3 = read_buf[i + 3*m];
61+
62+
// The index within this sub-DFT
63+
int k = i & (p - 1);
64+
65+
float2 twiddle = get_twiddle(k, p);
66+
// e^a * e^b = e^(a + b)
67+
float2 twiddle_2 = complex_mul(twiddle, twiddle);
68+
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
69+
70+
x_1 = complex_mul(x_1, twiddle);
71+
x_2 = complex_mul(x_2, twiddle_2);
72+
x_3 = complex_mul(x_3, twiddle_3);
73+
74+
float2 minus_i;
75+
minus_i.x = 0;
76+
minus_i.y = -1;
77+
78+
// Hard coded twiddle factors for DFT4
79+
float2 z_0 = x_0 + x_2;
80+
float2 z_1 = x_0 - x_2;
81+
float2 z_2 = x_1 + x_3;
82+
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
83+
84+
float2 y_0 = z_0 + z_2;
85+
float2 y_1 = z_1 + z_3;
86+
float2 y_2 = z_0 - z_2;
87+
float2 y_3 = z_1 - z_3;
88+
89+
int j = ((i - k) << 2) + k;
90+
91+
write_buf[j] = y_0;
92+
write_buf[j + p] = y_1;
93+
write_buf[j + 2*p] = y_2;
94+
write_buf[j + 3*p] = y_3;
95+
}
96+
97+
98+
// Each FFT is computed entirely in shared GPU memory.
99+
//
100+
// N is decomposed into radix-2 and radix-4 DFTs:
101+
// e.g. 128 = 2 * 4 * 4 * 4
102+
//
103+
// At each step we use n / 4 threads, each performing
104+
// a single-threaded radix-4 or radix-2 DFT.
105+
//
106+
// We provide the number of radix-2 and radix-4
107+
// steps at compile time for a ~20% performance boost.
108+
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
109+
[[kernel]] void fft(
110+
const device float2 *in [[buffer(0)]],
111+
device float2 * out [[buffer(1)]],
112+
uint3 thread_position_in_grid [[thread_position_in_grid]],
113+
uint3 threads_per_grid [[threads_per_grid]]) {
114+
115+
// Index of the DFT in batch
116+
int batch_idx = thread_position_in_grid.x * n;
117+
// The index in the DFT we're working on
118+
int i = thread_position_in_grid.y;
119+
// The number of the threads we're using for each DFT
120+
int m = threads_per_grid.y;
121+
122+
// Allocate 2 shared memory buffers for Stockham.
123+
// We alternate reading from one and writing to the other at each radix step.
124+
threadgroup float2 shared_in[n];
125+
threadgroup float2 shared_out[n];
126+
127+
// Pointers to facilitate Stockham buffer swapping
128+
threadgroup float2* read_buf = shared_in;
129+
threadgroup float2* write_buf = shared_out;
130+
threadgroup float2* tmp;
131+
132+
// Copy input into shared memory
133+
shared_in[i] = in[batch_idx + i];
134+
shared_in[i + m] = in[batch_idx + i + m];
135+
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
136+
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
137+
138+
threadgroup_barrier(mem_flags::mem_threadgroup);
139+
140+
int p = 1;
141+
142+
for (size_t r = 0; r < radix_2_steps; r++) {
143+
radix2(i, p, m*2, read_buf, write_buf);
144+
radix2(i + m, p, m*2, read_buf, write_buf);
145+
p *= 2;
146+
147+
threadgroup_barrier(mem_flags::mem_threadgroup);
148+
149+
// Stockham switch of buffers
150+
tmp = write_buf;
151+
write_buf = read_buf;
152+
read_buf = tmp;
153+
}
154+
155+
for (size_t r = 0; r < radix_4_steps; r++) {
156+
radix4(i, p, m, read_buf, write_buf);
157+
p *= 4;
158+
159+
threadgroup_barrier(mem_flags::mem_threadgroup);
160+
161+
// Stockham switch of buffers
162+
tmp = write_buf;
163+
write_buf = read_buf;
164+
read_buf = tmp;
165+
}
166+
167+
// Copy shared memory to output
168+
out[batch_idx + i] = read_buf[i];
169+
out[batch_idx + i + m] = read_buf[i + m];
170+
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
171+
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
172+
}
173+
174+
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
175+
template [[host_name("fft_" #name)]] \
176+
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
177+
const device float2* in [[buffer(0)]], \
178+
device float2* out [[buffer(1)]], \
179+
uint3 thread_position_in_grid [[thread_position_in_grid]], \
180+
uint3 threads_per_grid [[threads_per_grid]]);
181+
182+
183+
// Explicitly define kernels for each power of 2.
184+
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
185+
instantiate_fft(8, 8, 1, 1)
186+
instantiate_fft(16, 16, 0, 2)
187+
instantiate_fft(32, 32, 1, 2)
188+
instantiate_fft(64, 64, 0, 3)
189+
instantiate_fft(128, 128, 1, 3)
190+
instantiate_fft(256, 256, 0, 4)
191+
instantiate_fft(512, 512, 1, 4)
192+
instantiate_fft(1024, 1024, 0, 5)
193+
// 2048 is the max that will fit into 32KB of threadgroup memory.
194+
// TODO: implement 4 step FFT for larger n.
195+
instantiate_fft(2048, 2048, 1, 5)

mlx/backend/metal/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ inline void debug_set_primitive_buffer_label(
130130
#endif
131131
}
132132

133+
bool is_power_of_2(int n) {
134+
return ((n & (n - 1)) == 0) && n != 0;
135+
}
136+
133137
} // namespace
134138

135139
} // namespace mlx::core

0 commit comments

Comments
 (0)