2929import nibabel as nb
3030import numpy as np
3131import pytest
32+ from dipy .core .geometry import normalized_vector
3233
3334from nifreeze .data import load
3435from nifreeze .data .dmri .base import (
4142 from_nii ,
4243)
4344from nifreeze .data .dmri .utils import (
45+ DEFAULT_LOWB_THRESHOLD ,
4446 DTI_MIN_ORIENTATIONS ,
4547 GRADIENT_ABSENCE_ERROR_MSG ,
4648 GRADIENT_EXPECTED_COLUMNS_ERROR_MSG ,
@@ -176,8 +178,20 @@ def test_format_gradients_basic(value, expect_transpose):
176178 assert np .allclose (obtained , np .asarray (value ))
177179
178180
179- @pytest .mark .random_uniform_spatial_data ((2 , 2 , 2 , 4 ), 0.0 , 1.0 )
180- def test_dwi_post_init_errors (setup_random_uniform_spatial_data ):
181+ @pytest .mark .parametrize (
182+ "case_mark" ,
183+ [
184+ pytest .param (
185+ None ,
186+ marks = pytest .mark .random_uniform_spatial_data ((2 , 2 , 2 , 6 ), 0.0 , 1.0 ),
187+ ),
188+ pytest .param (
189+ None ,
190+ marks = pytest .mark .random_uniform_spatial_data ((2 , 2 , 2 , 4 ), 0.0 , 1.0 ),
191+ ),
192+ ],
193+ )
194+ def test_dwi_post_init_errors (setup_random_uniform_spatial_data , case_mark ):
181195 data , affine = setup_random_uniform_spatial_data
182196 with pytest .raises (ValueError , match = GRADIENT_ABSENCE_ERROR_MSG ):
183197 DWI (dataobj = data , affine = affine )
@@ -189,6 +203,89 @@ def test_dwi_post_init_errors(setup_random_uniform_spatial_data):
189203 DWI (dataobj = data , affine = affine , gradients = B_MATRIX [: data .shape [- 1 ], :])
190204
191205
206+ @pytest .mark .parametrize ("vol_size" , [(11 , 11 , 7 )])
207+ @pytest .mark .parametrize ("b0_count" , [0 , 1 ])
208+ @pytest .mark .parametrize ("bval_min, bval_max" , [(800.0 , 1200.0 )])
209+ @pytest .mark .parametrize ("provide_bzero" , [False , True ])
210+ def test_dwi_post_init_b0_handling (request , vol_size , b0_count , bval_min , bval_max , provide_bzero ):
211+ """Check b0 handling when instantiating the DWI class.
212+
213+ For each parameter combination:
214+ - Build a gradient table whose first `b0_count` volumes have b=0
215+ and the rest have b-values in the range (bvalmin, bvalmax);
216+ - Build a random dataobj of shape (**vol_size, N) where N is the number
217+ of DWI volumes;
218+ - If `provide_bzero` is True, pass explicit bzero data that must be
219+ preserved; else, rely on the bzero computed at instantiation, i.e.
220+ if a single bzero is provided, set the attribute to that value; if there
221+ are multiple bzeros, set the attribute to the median value.
222+ """
223+ rng = request .node .rng
224+
225+ # Choose n_vols safely above the minimum DTI orientations
226+ n_vols = max (10 , DTI_MIN_ORIENTATIONS + 2 )
227+
228+ # Build b-values array: first b0_count are zeros
229+ non_b0_count = n_vols - b0_count
230+ # Sample non-b0 bvals between min and max values
231+ rest_bvals = rng .uniform (bval_min , bval_max , size = non_b0_count )
232+ bvals = np .concatenate ((np .zeros (b0_count ), rest_bvals )).astype (int )
233+
234+ # Create bvecs and assemble gradients
235+ bzeros = np .zeros ((b0_count , 3 ))
236+ bvecs = normalized_vector (rng .random ((3 , non_b0_count )), axis = 0 ).T
237+ bvecs = np .vstack ((bzeros , bvecs ))
238+ gradients = np .column_stack ((bvecs , bvals ))
239+
240+ # Create random dataobj with shape
241+ dataobj = rng .standard_normal ((* vol_size , n_vols )).astype (float )
242+
243+ # Optionally supply a bzero
244+ provided = None
245+ affine = np .eye (4 )
246+ if provide_bzero :
247+ # Use a constant map so it's easy to assert equality
248+ provided = np .full ((* vol_size , max (1 , b0_count )), 42.0 , dtype = float ).squeeze ()
249+ dwi_obj = DWI (dataobj = dataobj , affine = affine , gradients = gradients , bzero = provided )
250+ else :
251+ dwi_obj = DWI (dataobj = dataobj , affine = affine , gradients = gradients )
252+
253+ # Count expected b0 frames according to the same threshold used by the code
254+ b0_mask = bvals <= DEFAULT_LOWB_THRESHOLD
255+ expected_b0_num = int (np .sum (b0_mask ))
256+ # In all cases where b0 frames existed (whether provided externally or not),
257+ # they should have been removed from the DWI object's internal gradients and
258+ # dataobj arrays
259+ expected_non_b0_count = n_vols - expected_b0_num
260+
261+ # If no b0 frames expected, bzero should be None (unless user provided one)
262+ if expected_b0_num == 0 and not provide_bzero :
263+ assert dwi_obj .bzero is None , (
264+ "Expected bzero to be None when no low-b frames and no provided bzero"
265+ )
266+ else :
267+ assert dwi_obj .bzero is not None
268+ # If provided_bzero is True, it must be preserved exactly
269+ if provide_bzero :
270+ assert provided is not None
271+ assert np .allclose (dwi_obj .bzero , provided )
272+ else :
273+ # When there are b0 frames and no provided bzero:
274+ # - If exactly one b0 frame, the stored bzero should be the 3D volume
275+ # - If multiple b0 frames, the stored bzero should be the median along last axis
276+ b0_vols = dataobj [
277+ ..., b0_mask
278+ ].squeeze () # shape (X,Y,Z,expected_b0_num) or (X,Y,Z) if 1
279+ expected_bzero = b0_vols if b0_vols .ndim == 3 else np .median (b0_vols , axis = - 1 )
280+ assert np .allclose (dwi_obj .bzero , expected_bzero )
281+
282+ assert dwi_obj .gradients .shape [0 ] == expected_non_b0_count
283+ assert dwi_obj .dataobj .shape [- 1 ] == expected_non_b0_count
284+
285+ assert np .allclose (dwi_obj .gradients , gradients [~ b0_mask ])
286+ assert np .allclose (dwi_obj .dataobj , dataobj [..., ~ b0_mask ])
287+
288+
192289@pytest .mark .random_gtab_data (10 , (1000 , 2000 ), 2 )
193290@pytest .mark .random_dwi_data (50 , (34 , 36 , 24 ), True )
194291@pytest .mark .parametrize ("row_major_gradients" , (False , True ))
0 commit comments