Skip to content

Commit ec545eb

Browse files
authored
Merge pull request #364 from jhlegarreta/tst/test-handling-b0-dwi-post-init
TST: Test comprehensively dMRI b0 volume handling
2 parents 301429b + 6829199 commit ec545eb

File tree

1 file changed

+99
-2
lines changed

1 file changed

+99
-2
lines changed

test/test_data_dmri.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import nibabel as nb
3030
import numpy as np
3131
import pytest
32+
from dipy.core.geometry import normalized_vector
3233

3334
from nifreeze.data import load
3435
from nifreeze.data.dmri.base import (
@@ -41,6 +42,7 @@
4142
from_nii,
4243
)
4344
from nifreeze.data.dmri.utils import (
45+
DEFAULT_LOWB_THRESHOLD,
4446
DTI_MIN_ORIENTATIONS,
4547
GRADIENT_ABSENCE_ERROR_MSG,
4648
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG,
@@ -176,8 +178,20 @@ def test_format_gradients_basic(value, expect_transpose):
176178
assert np.allclose(obtained, np.asarray(value))
177179

178180

179-
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
180-
def test_dwi_post_init_errors(setup_random_uniform_spatial_data):
181+
@pytest.mark.parametrize(
182+
"case_mark",
183+
[
184+
pytest.param(
185+
None,
186+
marks=pytest.mark.random_uniform_spatial_data((2, 2, 2, 6), 0.0, 1.0),
187+
),
188+
pytest.param(
189+
None,
190+
marks=pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0),
191+
),
192+
],
193+
)
194+
def test_dwi_post_init_errors(setup_random_uniform_spatial_data, case_mark):
181195
data, affine = setup_random_uniform_spatial_data
182196
with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG):
183197
DWI(dataobj=data, affine=affine)
@@ -189,6 +203,89 @@ def test_dwi_post_init_errors(setup_random_uniform_spatial_data):
189203
DWI(dataobj=data, affine=affine, gradients=B_MATRIX[: data.shape[-1], :])
190204

191205

206+
@pytest.mark.parametrize("vol_size", [(11, 11, 7)])
207+
@pytest.mark.parametrize("b0_count", [0, 1])
208+
@pytest.mark.parametrize("bval_min, bval_max", [(800.0, 1200.0)])
209+
@pytest.mark.parametrize("provide_bzero", [False, True])
210+
def test_dwi_post_init_b0_handling(request, vol_size, b0_count, bval_min, bval_max, provide_bzero):
211+
"""Check b0 handling when instantiating the DWI class.
212+
213+
For each parameter combination:
214+
- Build a gradient table whose first `b0_count` volumes have b=0
215+
and the rest have b-values in the range (bvalmin, bvalmax);
216+
- Build a random dataobj of shape (**vol_size, N) where N is the number
217+
of DWI volumes;
218+
- If `provide_bzero` is True, pass explicit bzero data that must be
219+
preserved; else, rely on the bzero computed at instantiation, i.e.
220+
if a single bzero is provided, set the attribute to that value; if there
221+
are multiple bzeros, set the attribute to the median value.
222+
"""
223+
rng = request.node.rng
224+
225+
# Choose n_vols safely above the minimum DTI orientations
226+
n_vols = max(10, DTI_MIN_ORIENTATIONS + 2)
227+
228+
# Build b-values array: first b0_count are zeros
229+
non_b0_count = n_vols - b0_count
230+
# Sample non-b0 bvals between min and max values
231+
rest_bvals = rng.uniform(bval_min, bval_max, size=non_b0_count)
232+
bvals = np.concatenate((np.zeros(b0_count), rest_bvals)).astype(int)
233+
234+
# Create bvecs and assemble gradients
235+
bzeros = np.zeros((b0_count, 3))
236+
bvecs = normalized_vector(rng.random((3, non_b0_count)), axis=0).T
237+
bvecs = np.vstack((bzeros, bvecs))
238+
gradients = np.column_stack((bvecs, bvals))
239+
240+
# Create random dataobj with shape
241+
dataobj = rng.standard_normal((*vol_size, n_vols)).astype(float)
242+
243+
# Optionally supply a bzero
244+
provided = None
245+
affine = np.eye(4)
246+
if provide_bzero:
247+
# Use a constant map so it's easy to assert equality
248+
provided = np.full((*vol_size, max(1, b0_count)), 42.0, dtype=float).squeeze()
249+
dwi_obj = DWI(dataobj=dataobj, affine=affine, gradients=gradients, bzero=provided)
250+
else:
251+
dwi_obj = DWI(dataobj=dataobj, affine=affine, gradients=gradients)
252+
253+
# Count expected b0 frames according to the same threshold used by the code
254+
b0_mask = bvals <= DEFAULT_LOWB_THRESHOLD
255+
expected_b0_num = int(np.sum(b0_mask))
256+
# In all cases where b0 frames existed (whether provided externally or not),
257+
# they should have been removed from the DWI object's internal gradients and
258+
# dataobj arrays
259+
expected_non_b0_count = n_vols - expected_b0_num
260+
261+
# If no b0 frames expected, bzero should be None (unless user provided one)
262+
if expected_b0_num == 0 and not provide_bzero:
263+
assert dwi_obj.bzero is None, (
264+
"Expected bzero to be None when no low-b frames and no provided bzero"
265+
)
266+
else:
267+
assert dwi_obj.bzero is not None
268+
# If provided_bzero is True, it must be preserved exactly
269+
if provide_bzero:
270+
assert provided is not None
271+
assert np.allclose(dwi_obj.bzero, provided)
272+
else:
273+
# When there are b0 frames and no provided bzero:
274+
# - If exactly one b0 frame, the stored bzero should be the 3D volume
275+
# - If multiple b0 frames, the stored bzero should be the median along last axis
276+
b0_vols = dataobj[
277+
..., b0_mask
278+
].squeeze() # shape (X,Y,Z,expected_b0_num) or (X,Y,Z) if 1
279+
expected_bzero = b0_vols if b0_vols.ndim == 3 else np.median(b0_vols, axis=-1)
280+
assert np.allclose(dwi_obj.bzero, expected_bzero)
281+
282+
assert dwi_obj.gradients.shape[0] == expected_non_b0_count
283+
assert dwi_obj.dataobj.shape[-1] == expected_non_b0_count
284+
285+
assert np.allclose(dwi_obj.gradients, gradients[~b0_mask])
286+
assert np.allclose(dwi_obj.dataobj, dataobj[..., ~b0_mask])
287+
288+
192289
@pytest.mark.random_gtab_data(10, (1000, 2000), 2)
193290
@pytest.mark.random_dwi_data(50, (34, 36, 24), True)
194291
@pytest.mark.parametrize("row_major_gradients", (False, True))

0 commit comments

Comments
 (0)