Skip to content

Commit 95c2a74

Browse files
authored
Move MLX tests from math to elemwise (#1748)
1 parent 0758de4 commit 95c2a74

File tree

2 files changed

+127
-121
lines changed

2 files changed

+127
-121
lines changed

tests/link/mlx/test_elemwise.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
add,
99
cos,
1010
eq,
11+
erfc,
12+
erfcx,
1113
exp,
1214
ge,
1315
gt,
1416
int_div,
1517
isinf,
18+
isnan,
1619
le,
1720
log,
1821
lt,
@@ -22,6 +25,7 @@
2225
prod,
2326
sigmoid,
2427
sin,
28+
softplus,
2529
sub,
2630
true_div,
2731
)
@@ -189,3 +193,126 @@ def test_elemwise_two_inputs(op) -> None:
189193
x_test = mx.array([1.0, 2.0, 3.0])
190194
y_test = mx.array([4.0, 5.0, 6.0])
191195
compare_mlx_and_py([x, y], out, [x_test, y_test])
196+
197+
198+
def test_switch() -> None:
199+
x = vector("x")
200+
y = vector("y")
201+
202+
out = switch(x > 0, y, x)
203+
204+
x_test = mx.array([-1.0, 2.0, 3.0])
205+
y_test = mx.array([4.0, 5.0, 6.0])
206+
207+
compare_mlx_and_py([x, y], out, [x_test, y_test])
208+
209+
210+
def test_int_div_specific() -> None:
211+
x = vector("x")
212+
y = vector("y")
213+
out = int_div(x, y)
214+
215+
# Test with integers that demonstrate floor division behavior
216+
x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
217+
y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])
218+
219+
compare_mlx_and_py([x, y], out, [x_test, y_test])
220+
221+
222+
def test_isnan() -> None:
223+
x = vector("x")
224+
out = isnan(x)
225+
226+
x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])
227+
228+
compare_mlx_and_py([x], out, [x_test])
229+
230+
231+
def test_isnan_edge_cases() -> None:
232+
from pytensor.tensor.type import scalar
233+
234+
x = scalar("x")
235+
out = isnan(x)
236+
237+
# Test individual cases
238+
test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10]
239+
240+
for test_val in test_cases:
241+
x_test = test_val
242+
compare_mlx_and_py([x], out, [x_test])
243+
244+
245+
def test_erfc() -> None:
246+
"""Test complementary error function"""
247+
x = vector("x")
248+
out = erfc(x)
249+
250+
# Test with various values including negative, positive, and zero
251+
x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])
252+
253+
compare_mlx_and_py([x], out, [x_test])
254+
255+
256+
def test_erfc_extreme_values() -> None:
257+
"""Test erfc with extreme values"""
258+
from functools import partial
259+
260+
x = vector("x")
261+
out = erfc(x)
262+
263+
# Test with larger values where erfc approaches 0 or 2
264+
x_test = mx.array([-3.0, -2.5, 2.5, 3.0])
265+
266+
# Use relaxed tolerance for extreme values due to numerical precision differences
267+
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)
268+
269+
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
270+
271+
272+
def test_erfcx() -> None:
273+
"""Test scaled complementary error function"""
274+
x = vector("x")
275+
out = erfcx(x)
276+
277+
# Test with positive values where erfcx is most numerically stable
278+
x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
279+
280+
compare_mlx_and_py([x], out, [x_test])
281+
282+
283+
def test_erfcx_small_values() -> None:
284+
"""Test erfcx with small values"""
285+
x = vector("x")
286+
out = erfcx(x)
287+
288+
# Test with small values
289+
x_test = mx.array([0.001, 0.01, 0.1, 0.2])
290+
291+
compare_mlx_and_py([x], out, [x_test])
292+
293+
294+
def test_softplus() -> None:
295+
"""Test softplus (log(1 + exp(x))) function"""
296+
x = vector("x")
297+
out = softplus(x)
298+
299+
# Test with normal range values
300+
x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])
301+
302+
compare_mlx_and_py([x], out, [x_test])
303+
304+
305+
def test_softplus_extreme_values() -> None:
306+
"""Test softplus with extreme values to verify numerical stability"""
307+
from functools import partial
308+
309+
x = vector("x")
310+
out = softplus(x)
311+
312+
# Test with extreme values where different branches of the implementation are used
313+
x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])
314+
315+
# Use relaxed tolerance for extreme values due to numerical precision differences
316+
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)
317+
318+
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)

tests/link/mlx/test_math.py

Lines changed: 0 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -29,127 +29,6 @@ def test_dot():
2929
np.testing.assert_allclose(actual, expected, rtol=1e-6)
3030

3131

32-
def test_switch() -> None:
33-
x = pt.vector("x")
34-
y = pt.vector("y")
35-
36-
out = pt.switch(x > 0, y, x)
37-
38-
x_test = mx.array([-1.0, 2.0, 3.0])
39-
y_test = mx.array([4.0, 5.0, 6.0])
40-
41-
compare_mlx_and_py([x, y], out, [x_test, y_test])
42-
43-
44-
def test_int_div_specific() -> None:
45-
x = pt.vector("x")
46-
y = pt.vector("y")
47-
out = pt.int_div(x, y)
48-
49-
# Test with integers that demonstrate floor division behavior
50-
x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
51-
y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])
52-
53-
compare_mlx_and_py([x, y], out, [x_test, y_test])
54-
55-
56-
def test_isnan() -> None:
57-
x = pt.vector("x")
58-
out = pt.isnan(x)
59-
60-
x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])
61-
62-
compare_mlx_and_py([x], out, [x_test])
63-
64-
65-
def test_isnan_edge_cases() -> None:
66-
x = pt.scalar("x")
67-
out = pt.isnan(x)
68-
69-
# Test individual cases
70-
test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10]
71-
72-
for test_val in test_cases:
73-
x_test = test_val
74-
compare_mlx_and_py([x], out, [x_test])
75-
76-
77-
def test_erfc() -> None:
78-
"""Test complementary error function"""
79-
x = pt.vector("x")
80-
out = pt.erfc(x)
81-
82-
# Test with various values including negative, positive, and zero
83-
x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])
84-
85-
compare_mlx_and_py([x], out, [x_test])
86-
87-
88-
def test_erfc_extreme_values() -> None:
89-
"""Test erfc with extreme values"""
90-
x = pt.vector("x")
91-
out = pt.erfc(x)
92-
93-
# Test with larger values where erfc approaches 0 or 2
94-
x_test = mx.array([-3.0, -2.5, 2.5, 3.0])
95-
96-
# Use relaxed tolerance for extreme values due to numerical precision differences
97-
from functools import partial
98-
99-
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)
100-
101-
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
102-
103-
104-
def test_erfcx() -> None:
105-
"""Test scaled complementary error function"""
106-
x = pt.vector("x")
107-
out = pt.erfcx(x)
108-
109-
# Test with positive values where erfcx is most numerically stable
110-
x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
111-
112-
compare_mlx_and_py([x], out, [x_test])
113-
114-
115-
def test_erfcx_small_values() -> None:
116-
"""Test erfcx with small values"""
117-
x = pt.vector("x")
118-
out = pt.erfcx(x)
119-
120-
# Test with small values
121-
x_test = mx.array([0.001, 0.01, 0.1, 0.2])
122-
123-
compare_mlx_and_py([x], out, [x_test])
124-
125-
126-
def test_softplus() -> None:
127-
"""Test softplus (log(1 + exp(x))) function"""
128-
x = pt.vector("x")
129-
out = pt.softplus(x)
130-
131-
# Test with normal range values
132-
x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])
133-
134-
compare_mlx_and_py([x], out, [x_test])
135-
136-
137-
def test_softplus_extreme_values() -> None:
138-
"""Test softplus with extreme values to verify numerical stability"""
139-
x = pt.vector("x")
140-
out = pt.softplus(x)
141-
142-
# Test with extreme values where different branches of the implementation are used
143-
x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])
144-
145-
# Use relaxed tolerance for extreme values due to numerical precision differences
146-
from functools import partial
147-
148-
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)
149-
150-
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
151-
152-
15332
def test_mlx_max_and_argmax():
15433
# Test that a single output of a multi-output `Op` can be used as input to
15534
# another `Op`

0 commit comments

Comments
 (0)