Skip to content

Commit bdf96ac

Browse files
authored
Merge pull request #360 from jhlegarreta/tst/test-base-data-to-nifti
TST: Add tests to check the base data module `to_nifti` behavior
2 parents 4d4662f + b27b092 commit bdf96ac

File tree

1 file changed

+217
-2
lines changed

1 file changed

+217
-2
lines changed

test/test_data_base.py

Lines changed: 217 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Test dataset base class."""
2424

25+
import itertools
2526
import re
2627
from pathlib import Path
2728
from tempfile import TemporaryDirectory
@@ -43,13 +44,94 @@
4344
DATAOBJ_OBJECT_ERROR_MSG,
4445
_has_dim_size,
4546
_has_ndim,
47+
to_nifti,
4648
)
47-
from nifreeze.utils.ndimage import get_data
49+
from nifreeze.utils.ndimage import get_data, load_api
4850

4951
DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
5052
DEFAULT_RANDOM_DATASET_SIZE = int(np.prod(DEFAULT_RANDOM_DATASET_SHAPE[:3]))
5153

5254

55+
# Dummy transform classes and functions to monkeypatch into the real module
56+
class DummyTransform:
57+
def __init__(self, idx):
58+
self.idx = idx
59+
60+
def to_filename(self, path):
61+
# Create a tiny marker file so tests can check it was written and its
62+
# contents
63+
path = Path(path)
64+
path.parent.mkdir(parents=True, exist_ok=True)
65+
path.write_text(f"transform-{self.idx}")
66+
67+
68+
class DummyLinearTransformsMapping:
69+
"""A class that mimics the iterable mapping of linear transforms.
70+
71+
Yields DummyTransform instances when iterated. Also supports len()
72+
and indexing to be interchangeable with sequence-like expectations.
73+
"""
74+
75+
def __init__(self, transforms, reference=None):
76+
# Determine number of transforms in an explicit, non-broad-except way.
77+
if transforms is None:
78+
n = 0
79+
elif hasattr(transforms, "__len__"):
80+
# len() may raise TypeError for objects that do not support it
81+
try:
82+
n = len(transforms)
83+
except TypeError:
84+
# Treat non-sized objects as zero-length for this test helper
85+
n = 0
86+
else:
87+
# Explicitly raise for unexpected types to fail fast and be explicit
88+
raise TypeError("transforms must be a sequence or None")
89+
90+
# Return one DummyTransform per motion_affine
91+
self._xforms = [DummyTransform(i) for i in range(n)]
92+
93+
def __iter__(self):
94+
return iter(self._xforms)
95+
96+
def __len__(self):
97+
return len(self._xforms)
98+
99+
def __getitem__(self, idx):
100+
return self._xforms[idx]
101+
102+
103+
def dummy_apply(transforms, spatialimage, order=3):
104+
"""A deterministic 'resampling' that modifies the data so tests can verify
105+
that apply() and the transforms mapping were actually used.
106+
107+
It returns a new Nifti1Image whose data is the original frame plus
108+
the transform index (transforms.idx). This makes each frame distinct and
109+
easily predictable.
110+
"""
111+
data = np.asanyarray(spatialimage.dataobj).copy()
112+
# Mutate in a simple, deterministic way that depends on transform index
113+
data = data + int(getattr(transforms, "idx", 0))
114+
return nb.Nifti1Image(data, spatialimage.affine, spatialimage.header)
115+
116+
117+
class DummyImageGrid:
118+
def __init__(self, shape, affine):
119+
self.shape = shape
120+
self.affine = affine
121+
122+
123+
class DummyDataset(BaseDataset):
124+
def __init__(self, shape, datahdr=None, motion_affines=None, dtype=np.int16):
125+
self.dataobj = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
126+
self.affine = np.eye(4)
127+
self.datahdr = datahdr
128+
self.motion_affines = motion_affines
129+
130+
def __getitem__(self, idx):
131+
# to_nifti expects dataset[idx] to return a sequence whose first item is the 3D frame
132+
return (self.dataobj[..., idx],)
133+
134+
53135
@pytest.mark.parametrize(
54136
"setup_random_uniform_spatial_data",
55137
[
@@ -242,7 +324,7 @@ def test_to_filename_and_from_filename(random_dataset: BaseDataset):
242324
assert np.allclose(random_dataset.dataobj, ds2.dataobj)
243325

244326

245-
def test_to_nifti(random_dataset: BaseDataset):
327+
def test_object_to_nifti(random_dataset: BaseDataset):
246328
"""Test writing a dataset to a NIfTI file."""
247329
with TemporaryDirectory() as tmpdir:
248330
nifti_file = Path(tmpdir) / "test_dataset.nii.gz"
@@ -258,6 +340,139 @@ def test_to_nifti(random_dataset: BaseDataset):
258340
assert np.allclose(data, random_dataset.dataobj)
259341

260342

343+
import warnings
344+
345+
346+
@pytest.mark.parametrize(
347+
"filename_is_none, motion_affines_present, write_hmxfms, expected_message",
348+
[
349+
# write_hmxfms True but no filename
350+
(True, True, True, "write_hmxfms is set to True, but no filename was provided."),
351+
# write_hmxfms True, filename given, but no motion affines
352+
(
353+
False,
354+
False,
355+
True,
356+
"write_hmxfms is set to True, but no motion affines were found. Skipping.",
357+
),
358+
],
359+
)
360+
def test_to_nifti_warnings(
361+
tmp_path, monkeypatch, filename_is_none, motion_affines_present, write_hmxfms, expected_message
362+
):
363+
# Monkeypatch the helpers in the module where to_nifti is defined
364+
import nifreeze.data.base as base_mod
365+
366+
monkeypatch.setattr(base_mod, "LinearTransformsMapping", DummyLinearTransformsMapping)
367+
monkeypatch.setattr(base_mod, "apply", dummy_apply)
368+
monkeypatch.setattr(base_mod, "ImageGrid", DummyImageGrid)
369+
370+
n_frames = 3
371+
shape = (4, 4, 2, n_frames)
372+
373+
motion_affines = [np.eye(4) for _ in range(n_frames)] if motion_affines_present else None
374+
375+
dataset = DummyDataset(shape, datahdr=None, motion_affines=motion_affines)
376+
377+
filename = None
378+
if not filename_is_none:
379+
filename = tmp_path / "data.nii.gz"
380+
381+
with warnings.catch_warnings(record=True) as w:
382+
warnings.simplefilter("always")
383+
_ = to_nifti(dataset, filename=filename, write_hmxfms=write_hmxfms, order=1)
384+
385+
assert any(expected_message in str(x.message) for x in w)
386+
387+
388+
@pytest.mark.parametrize(
389+
"filename_is_none, write_hmxfms, motion_affines_present, datahdr_present",
390+
list(itertools.product([True, False], repeat=4)),
391+
)
392+
def test_to_nifti(
393+
tmp_path, monkeypatch, filename_is_none, write_hmxfms, motion_affines_present, datahdr_present
394+
):
395+
# Monkeypatch the helpers in the module where to_nifti is defined
396+
import nifreeze.data.base as base_mod
397+
398+
monkeypatch.setattr(base_mod, "LinearTransformsMapping", DummyLinearTransformsMapping)
399+
monkeypatch.setattr(base_mod, "apply", dummy_apply)
400+
monkeypatch.setattr(base_mod, "ImageGrid", DummyImageGrid)
401+
402+
n_frames = 3
403+
shape = (4, 4, 2, n_frames)
404+
dtype = np.int16
405+
406+
datahdr = None
407+
if datahdr_present:
408+
hdr = nb.Nifti1Header()
409+
hdr.set_data_dtype(dtype)
410+
datahdr = hdr
411+
412+
motion_affines = None
413+
if motion_affines_present:
414+
motion_affines = [np.eye(4) for _ in range(n_frames)]
415+
416+
dataset = DummyDataset(shape, datahdr=datahdr, motion_affines=motion_affines, dtype=dtype)
417+
418+
filename = None
419+
if not filename_is_none:
420+
filename = tmp_path / "data.nii.gz"
421+
422+
# Suppress warnings in this test
423+
with warnings.catch_warnings():
424+
warnings.simplefilter("ignore")
425+
nii = to_nifti(dataset, filename=filename, write_hmxfms=write_hmxfms, order=1)
426+
427+
# Check returned data
428+
assert isinstance(nii, nb.Nifti1Image)
429+
assert nii.shape == dataset.dataobj.shape
430+
431+
expected = dataset.dataobj.copy()
432+
if motion_affines_present:
433+
# If motion affines are present, fake_apply adds the frame index to each
434+
# frame
435+
for i in range(n_frames):
436+
expected[..., i] = expected[..., i] + i
437+
assert np.array_equal(nii.dataobj, expected), (
438+
"Resampled data should reflect fake_apply modifications"
439+
)
440+
else:
441+
# No resampling; data should be identical to original
442+
assert np.array_equal(nii.dataobj, dataset.dataobj)
443+
444+
# Header behavior
445+
if datahdr_present:
446+
assert datahdr is not None
447+
assert nii.header.get_data_dtype() == datahdr.get_data_dtype()
448+
else:
449+
xyzt = nii.header.get_xyzt_units()
450+
assert xyzt[0].lower() == "mm"
451+
452+
# If filename was provided, file should exist and equal to transformed data
453+
if filename is not None:
454+
assert filename.is_file()
455+
nii_load = load_api(filename, nb.Nifti1Image)
456+
assert np.array_equal(nii_load.get_fdata(), expected)
457+
else:
458+
assert not any(tmp_path.iterdir()), "Directory is not empty"
459+
460+
# When motion_affines present and write_hmxfms True and filename provided,
461+
# x5 files should be written
462+
if motion_affines_present and write_hmxfms and filename is not None:
463+
# The same file is written at every iteration, so earlier transforms are
464+
# overwritten and only the last transform remains on disk
465+
found_x5 = list(tmp_path.glob("*.x5"))
466+
assert len(found_x5) == 1
467+
x5_path = filename.with_suffix("").with_suffix(".x5")
468+
assert x5_path.is_file()
469+
content = x5_path.read_text()
470+
assert content == f"transform-{n_frames - 1}"
471+
else:
472+
found_x5 = list(tmp_path.glob("*.x5"))
473+
assert len(found_x5) == 0
474+
475+
261476
def test_load_hdf5(random_dataset: BaseDataset):
262477
"""Test the 'load' function with an HDF5 file."""
263478
with TemporaryDirectory() as tmpdir:

0 commit comments

Comments
 (0)