Skip to content

Commit f3d1e29

Browse files
committed
TST: Add comprehensive testing for data.dmri.io.to_nifti
Add comprehensive testing for `data.dmri.io.to_nifti`. Specifically, parametrize the test over the function parameters and provide a range of values to exhaustively check that its behavior matches the expected one.
1 parent 301429b commit f3d1e29

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

test/test_data_dmri.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
"""Unit tests exercising the dMRI data structure."""
2424

2525
import re
26+
import warnings
2627
from pathlib import Path
2728

2829
import attrs
2930
import nibabel as nb
3031
import numpy as np
3132
import pytest
33+
from dipy.core.geometry import normalized_vector
3234

3335
from nifreeze.data import load
3436
from nifreeze.data.dmri.base import (
@@ -39,6 +41,7 @@
3941
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG,
4042
GRADIENT_DATA_MISSING_ERROR,
4143
from_nii,
44+
to_nifti,
4245
)
4346
from nifreeze.data.dmri.utils import (
4447
DTI_MIN_ORIENTATIONS,
@@ -509,6 +512,164 @@ def test_load_gradients_missing(tmp_path, setup_random_dwi_data):
509512
from_nii(dwi_fname)
510513

511514

515+
@pytest.mark.parametrize("vol_size", [(4, 4, 5)])
516+
@pytest.mark.parametrize("b0_count", [0, 1])
517+
@pytest.mark.parametrize("bval_min, bval_max", [(800.0, 1200.0)])
518+
@pytest.mark.parametrize("provide_bzero", [False, True])
519+
@pytest.mark.parametrize("insert_b0", [False, True])
520+
@pytest.mark.parametrize("motion_affines", [None, 2 * np.eye(4)])
521+
@pytest.mark.parametrize("bvals_dec_places, bvecs_dec_places", [(2, 6), (1, 4)])
522+
@pytest.mark.parametrize("file_basename", [None, "dwi.nii.gz"])
523+
def test_to_nifti(
524+
request,
525+
tmp_path,
526+
monkeypatch,
527+
vol_size,
528+
b0_count,
529+
bval_min,
530+
bval_max,
531+
provide_bzero,
532+
insert_b0,
533+
motion_affines,
534+
bvals_dec_places,
535+
bvecs_dec_places,
536+
file_basename,
537+
):
538+
rng = request.node.rng
539+
540+
# Choose n_vols safely above the minimum DTI orientations
541+
n_vols = max(10, DTI_MIN_ORIENTATIONS + 2)
542+
543+
# Build b-values array: first b0_count are zeros
544+
non_b0_count = n_vols - b0_count
545+
# Sample non-b0 bvals between min and max values
546+
rest_bvals = rng.uniform(bval_min, bval_max, size=non_b0_count)
547+
bvals = np.concatenate((np.zeros(b0_count), rest_bvals)).astype(int)
548+
549+
# Create bvecs and assemble gradients
550+
bzeros = np.zeros((b0_count, 3))
551+
bvecs = normalized_vector(rng.random((3, non_b0_count)), axis=0).T
552+
bvecs = np.vstack((bzeros, bvecs))
553+
gradients = np.column_stack((bvecs, bvals))
554+
555+
# Create random dataobj with shape
556+
dataobj = rng.standard_normal((*vol_size, n_vols)).astype(float)
557+
558+
# Optionally supply a bzero
559+
provided = None
560+
affine = np.eye(4)
561+
_motion_affines = (
562+
np.stack([motion_affines] * non_b0_count) if motion_affines is not None else None
563+
)
564+
if provide_bzero:
565+
# Use a constant map so it's easy to assert equality
566+
provided = np.full((*vol_size, max(1, b0_count)), 42.0, dtype=float).squeeze()
567+
dwi_obj = DWI(
568+
dataobj=dataobj,
569+
affine=affine,
570+
motion_affines=_motion_affines,
571+
gradients=gradients,
572+
bzero=provided,
573+
)
574+
else:
575+
dwi_obj = DWI(
576+
dataobj=dataobj, affine=affine, motion_affines=_motion_affines, gradients=gradients
577+
)
578+
579+
_filename = tmp_path / file_basename if file_basename is not None else file_basename
580+
581+
# Monkeypatch the to_nifti alias to only perform essential operations for
582+
# the purpose of this test
583+
def simple_to_nifti(_dataset, filename=None, write_hmxfms=None, order=None):
584+
_ = write_hmxfms
585+
_ = order
586+
_nii = nb.Nifti1Image(_dataset.dataobj, _dataset.affine)
587+
if filename is not None:
588+
_nii.to_filename(filename)
589+
return _nii
590+
591+
monkeypatch.setattr("nifreeze.data.dmri.io._base_to_nifti", simple_to_nifti)
592+
593+
with warnings.catch_warnings(record=True) as caught:
594+
nii = to_nifti(
595+
dwi_obj,
596+
_filename,
597+
write_hmxfms=False,
598+
order=3,
599+
insert_b0=insert_b0,
600+
bvals_dec_places=bvals_dec_places,
601+
bvecs_dec_places=bvecs_dec_places,
602+
)
603+
604+
no_bzero = dwi_obj.bzero is None or not insert_b0
605+
606+
# Check the warning
607+
if no_bzero:
608+
if insert_b0:
609+
assert (
610+
str(caught[0].message)
611+
== "Ignoring ``insert_b0`` argument as the data object's bzero field is unset"
612+
)
613+
614+
bvecs_dwi = dwi_obj.bvecs
615+
bvals_dwi = dwi_obj.bvals
616+
# Transform bvecs if motion affines are present
617+
if dwi_obj.motion_affines is not None:
618+
rotated = [
619+
transform_fsl_bvec(_bvec, _affine, dwi_obj.affine, invert=True)
620+
for _bvec, _affine in zip(bvecs_dwi, dwi_obj.motion_affines, strict=True)
621+
]
622+
bvecs_dwi = np.asarray(rotated)
623+
624+
# Check the primary NIfTI output
625+
_dataobj = dwi_obj.dataobj
626+
# Concatenate the b0 if the primary data has a b0 volume or if it was
627+
# requested to do so
628+
if not no_bzero:
629+
assert dwi_obj.bzero is not None
630+
# ToDo
631+
# The code will concatenate as many zeros as they exist to the data
632+
_dataobj = np.concatenate((dwi_obj.bzero[..., np.newaxis], dwi_obj.dataobj), axis=-1)
633+
# But when inserting b0 data to the gradients, it inserts a single b0.
634+
# Here I will insert as many values as b0 volumes to make the test fail
635+
dwi_b0_count = dwi_obj.bzero.shape[-1] if dwi_obj.bzero.ndim == 4 else 1
636+
bvals_dwi = np.concatenate((np.zeros(dwi_b0_count), bvals_dwi))
637+
bvecs_dwi = np.vstack((np.zeros((dwi_b0_count, bvecs_dwi.shape[1])), bvecs_dwi))
638+
639+
assert isinstance(nii, nb.Nifti1Image)
640+
assert np.allclose(nii.get_fdata(), _dataobj)
641+
assert np.allclose(nii.affine, dwi_obj.affine)
642+
643+
# Check the written files, if any
644+
if _filename is None:
645+
assert not any(tmp_path.iterdir()), "Directory is not empty"
646+
else:
647+
# Check the written NIfTI file
648+
assert _filename.is_file()
649+
650+
_nii_load = load_api(_filename, nb.Nifti1Image)
651+
652+
# Build a NIfTI file with the data object that potentially contains
653+
# concatenated b0 data
654+
_nii_dataobj = nb.Nifti1Image(_dataobj, nii.affine, nii.header)
655+
656+
np.allclose(_nii_dataobj.get_fdata(), _nii_load.get_fdata())
657+
np.allclose(_nii_dataobj.affine, _nii_load.affine)
658+
659+
# Check gradients
660+
bvecs_file = _filename.with_suffix("").with_suffix(".bvec")
661+
bvals_file = _filename.with_suffix("").with_suffix(".bval")
662+
assert bvals_file.is_file()
663+
assert bvecs_file.is_file()
664+
665+
# Read the files
666+
bvals_from_file = np.loadtxt(bvals_file)
667+
bvecs_from_file = np.loadtxt(bvecs_file).T
668+
669+
assert np.allclose(bvals_from_file, bvals_dwi, rtol=0, atol=10**-bvals_dec_places)
670+
assert np.allclose(bvecs_from_file, bvecs_dwi, rtol=0, atol=10**-bvecs_dec_places)
671+
672+
512673
@pytest.mark.skip(reason="to_nifti takes absurdly long")
513674
@pytest.mark.parametrize("insert_b0", (False, True))
514675
@pytest.mark.parametrize("rotate_bvecs", (False, True))

0 commit comments

Comments
 (0)