|
8 | 8 | add, |
9 | 9 | cos, |
10 | 10 | eq, |
| 11 | + erfc, |
| 12 | + erfcx, |
11 | 13 | exp, |
12 | 14 | ge, |
13 | 15 | gt, |
14 | 16 | int_div, |
15 | 17 | isinf, |
| 18 | + isnan, |
16 | 19 | le, |
17 | 20 | log, |
18 | 21 | lt, |
|
22 | 25 | prod, |
23 | 26 | sigmoid, |
24 | 27 | sin, |
| 28 | + softplus, |
25 | 29 | sub, |
26 | 30 | true_div, |
27 | 31 | ) |
@@ -189,3 +193,126 @@ def test_elemwise_two_inputs(op) -> None: |
189 | 193 | x_test = mx.array([1.0, 2.0, 3.0]) |
190 | 194 | y_test = mx.array([4.0, 5.0, 6.0]) |
191 | 195 | compare_mlx_and_py([x, y], out, [x_test, y_test]) |
| 196 | + |
| 197 | + |
| 198 | +def test_switch() -> None: |
| 199 | + x = vector("x") |
| 200 | + y = vector("y") |
| 201 | + |
| 202 | + out = switch(x > 0, y, x) |
| 203 | + |
| 204 | + x_test = mx.array([-1.0, 2.0, 3.0]) |
| 205 | + y_test = mx.array([4.0, 5.0, 6.0]) |
| 206 | + |
| 207 | + compare_mlx_and_py([x, y], out, [x_test, y_test]) |
| 208 | + |
| 209 | + |
| 210 | +def test_int_div_specific() -> None: |
| 211 | + x = vector("x") |
| 212 | + y = vector("y") |
| 213 | + out = int_div(x, y) |
| 214 | + |
| 215 | + # Test with integers that demonstrate floor division behavior |
| 216 | + x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0]) |
| 217 | + y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0]) |
| 218 | + |
| 219 | + compare_mlx_and_py([x, y], out, [x_test, y_test]) |
| 220 | + |
| 221 | + |
| 222 | +def test_isnan() -> None: |
| 223 | + x = vector("x") |
| 224 | + out = isnan(x) |
| 225 | + |
| 226 | + x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf]) |
| 227 | + |
| 228 | + compare_mlx_and_py([x], out, [x_test]) |
| 229 | + |
| 230 | + |
| 231 | +def test_isnan_edge_cases() -> None: |
| 232 | + from pytensor.tensor.type import scalar |
| 233 | + |
| 234 | + x = scalar("x") |
| 235 | + out = isnan(x) |
| 236 | + |
| 237 | + # Test individual cases |
| 238 | + test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10] |
| 239 | + |
| 240 | + for test_val in test_cases: |
| 241 | + x_test = test_val |
| 242 | + compare_mlx_and_py([x], out, [x_test]) |
| 243 | + |
| 244 | + |
| 245 | +def test_erfc() -> None: |
| 246 | + """Test complementary error function""" |
| 247 | + x = vector("x") |
| 248 | + out = erfc(x) |
| 249 | + |
| 250 | + # Test with various values including negative, positive, and zero |
| 251 | + x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1]) |
| 252 | + |
| 253 | + compare_mlx_and_py([x], out, [x_test]) |
| 254 | + |
| 255 | + |
| 256 | +def test_erfc_extreme_values() -> None: |
| 257 | + """Test erfc with extreme values""" |
| 258 | + from functools import partial |
| 259 | + |
| 260 | + x = vector("x") |
| 261 | + out = erfc(x) |
| 262 | + |
| 263 | + # Test with larger values where erfc approaches 0 or 2 |
| 264 | + x_test = mx.array([-3.0, -2.5, 2.5, 3.0]) |
| 265 | + |
| 266 | + # Use relaxed tolerance for extreme values due to numerical precision differences |
| 267 | + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6) |
| 268 | + |
| 269 | + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) |
| 270 | + |
| 271 | + |
| 272 | +def test_erfcx() -> None: |
| 273 | + """Test scaled complementary error function""" |
| 274 | + x = vector("x") |
| 275 | + out = erfcx(x) |
| 276 | + |
| 277 | + # Test with positive values where erfcx is most numerically stable |
| 278 | + x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5]) |
| 279 | + |
| 280 | + compare_mlx_and_py([x], out, [x_test]) |
| 281 | + |
| 282 | + |
| 283 | +def test_erfcx_small_values() -> None: |
| 284 | + """Test erfcx with small values""" |
| 285 | + x = vector("x") |
| 286 | + out = erfcx(x) |
| 287 | + |
| 288 | + # Test with small values |
| 289 | + x_test = mx.array([0.001, 0.01, 0.1, 0.2]) |
| 290 | + |
| 291 | + compare_mlx_and_py([x], out, [x_test]) |
| 292 | + |
| 293 | + |
| 294 | +def test_softplus() -> None: |
| 295 | + """Test softplus (log(1 + exp(x))) function""" |
| 296 | + x = vector("x") |
| 297 | + out = softplus(x) |
| 298 | + |
| 299 | + # Test with normal range values |
| 300 | + x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0]) |
| 301 | + |
| 302 | + compare_mlx_and_py([x], out, [x_test]) |
| 303 | + |
| 304 | + |
| 305 | +def test_softplus_extreme_values() -> None: |
| 306 | + """Test softplus with extreme values to verify numerical stability""" |
| 307 | + from functools import partial |
| 308 | + |
| 309 | + x = vector("x") |
| 310 | + out = softplus(x) |
| 311 | + |
| 312 | + # Test with extreme values where different branches of the implementation are used |
| 313 | + x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0]) |
| 314 | + |
| 315 | + # Use relaxed tolerance for extreme values due to numerical precision differences |
| 316 | + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8) |
| 317 | + |
| 318 | + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) |
0 commit comments