2222#
2323"""Test dataset base class."""
2424
25+ import itertools
2526import re
2627from pathlib import Path
2728from tempfile import TemporaryDirectory
4344 DATAOBJ_OBJECT_ERROR_MSG ,
4445 _has_dim_size ,
4546 _has_ndim ,
47+ to_nifti ,
4648)
47- from nifreeze .utils .ndimage import get_data
49+ from nifreeze .utils .ndimage import get_data , load_api
4850
4951DEFAULT_RANDOM_DATASET_SHAPE = (32 , 32 , 32 , 5 )
5052DEFAULT_RANDOM_DATASET_SIZE = int (np .prod (DEFAULT_RANDOM_DATASET_SHAPE [:3 ]))
5153
5254
55+ # Dummy transform classes and functions to monkeypatch into the real module
56+ class DummyTransform :
57+ def __init__ (self , idx ):
58+ self .idx = idx
59+
60+ def to_filename (self , path ):
61+ # Create a tiny marker file so tests can check it was written and its
62+ # contents
63+ path = Path (path )
64+ path .parent .mkdir (parents = True , exist_ok = True )
65+ path .write_text (f"transform-{ self .idx } " )
66+
67+
68+ class DummyLinearTransformsMapping :
69+ """A class that mimics the iterable mapping of linear transforms.
70+
71+ Yields DummyTransform instances when iterated. Also supports len()
72+ and indexing to be interchangeable with sequence-like expectations.
73+ """
74+
75+ def __init__ (self , transforms , reference = None ):
76+ # Determine number of transforms in an explicit, non-broad-except way.
77+ if transforms is None :
78+ n = 0
79+ elif hasattr (transforms , "__len__" ):
80+ # len() may raise TypeError for objects that do not support it
81+ try :
82+ n = len (transforms )
83+ except TypeError :
84+ # Treat non-sized objects as zero-length for this test helper
85+ n = 0
86+ else :
87+ # Explicitly raise for unexpected types to fail fast and be explicit
88+ raise TypeError ("transforms must be a sequence or None" )
89+
90+ # Return one DummyTransform per motion_affine
91+ self ._xforms = [DummyTransform (i ) for i in range (n )]
92+
93+ def __iter__ (self ):
94+ return iter (self ._xforms )
95+
96+ def __len__ (self ):
97+ return len (self ._xforms )
98+
99+ def __getitem__ (self , idx ):
100+ return self ._xforms [idx ]
101+
102+
103+ def dummy_apply (transforms , spatialimage , order = 3 ):
104+ """A deterministic 'resampling' that modifies the data so tests can verify
105+ that apply() and the transforms mapping were actually used.
106+
107+ It returns a new Nifti1Image whose data is the original frame plus
108+ the transform index (transforms.idx). This makes each frame distinct and
109+ easily predictable.
110+ """
111+ data = np .asanyarray (spatialimage .dataobj ).copy ()
112+ # Mutate in a simple, deterministic way that depends on transform index
113+ data = data + int (getattr (transforms , "idx" , 0 ))
114+ return nb .Nifti1Image (data , spatialimage .affine , spatialimage .header )
115+
116+
117+ class DummyImageGrid :
118+ def __init__ (self , shape , affine ):
119+ self .shape = shape
120+ self .affine = affine
121+
122+
123+ class DummyDataset (BaseDataset ):
124+ def __init__ (self , shape , datahdr = None , motion_affines = None , dtype = np .int16 ):
125+ self .dataobj = np .arange (np .prod (shape ), dtype = dtype ).reshape (shape )
126+ self .affine = np .eye (4 )
127+ self .datahdr = datahdr
128+ self .motion_affines = motion_affines
129+
130+ def __getitem__ (self , idx ):
131+ # to_nifti expects dataset[idx] to return a sequence whose first item is the 3D frame
132+ return (self .dataobj [..., idx ],)
133+
134+
53135@pytest .mark .parametrize (
54136 "setup_random_uniform_spatial_data" ,
55137 [
@@ -242,7 +324,7 @@ def test_to_filename_and_from_filename(random_dataset: BaseDataset):
242324 assert np .allclose (random_dataset .dataobj , ds2 .dataobj )
243325
244326
245- def test_to_nifti (random_dataset : BaseDataset ):
327+ def test_object_to_nifti (random_dataset : BaseDataset ):
246328 """Test writing a dataset to a NIfTI file."""
247329 with TemporaryDirectory () as tmpdir :
248330 nifti_file = Path (tmpdir ) / "test_dataset.nii.gz"
@@ -258,6 +340,139 @@ def test_to_nifti(random_dataset: BaseDataset):
258340 assert np .allclose (data , random_dataset .dataobj )
259341
260342
343+ import warnings
344+
345+
346+ @pytest .mark .parametrize (
347+ "filename_is_none, motion_affines_present, write_hmxfms, expected_message" ,
348+ [
349+ # write_hmxfms True but no filename
350+ (True , True , True , "write_hmxfms is set to True, but no filename was provided." ),
351+ # write_hmxfms True, filename given, but no motion affines
352+ (
353+ False ,
354+ False ,
355+ True ,
356+ "write_hmxfms is set to True, but no motion affines were found. Skipping." ,
357+ ),
358+ ],
359+ )
360+ def test_to_nifti_warnings (
361+ tmp_path , monkeypatch , filename_is_none , motion_affines_present , write_hmxfms , expected_message
362+ ):
363+ # Monkeypatch the helpers in the module where to_nifti is defined
364+ import nifreeze .data .base as base_mod
365+
366+ monkeypatch .setattr (base_mod , "LinearTransformsMapping" , DummyLinearTransformsMapping )
367+ monkeypatch .setattr (base_mod , "apply" , dummy_apply )
368+ monkeypatch .setattr (base_mod , "ImageGrid" , DummyImageGrid )
369+
370+ n_frames = 3
371+ shape = (4 , 4 , 2 , n_frames )
372+
373+ motion_affines = [np .eye (4 ) for _ in range (n_frames )] if motion_affines_present else None
374+
375+ dataset = DummyDataset (shape , datahdr = None , motion_affines = motion_affines )
376+
377+ filename = None
378+ if not filename_is_none :
379+ filename = tmp_path / "data.nii.gz"
380+
381+ with warnings .catch_warnings (record = True ) as w :
382+ warnings .simplefilter ("always" )
383+ _ = to_nifti (dataset , filename = filename , write_hmxfms = write_hmxfms , order = 1 )
384+
385+ assert any (expected_message in str (x .message ) for x in w )
386+
387+
388+ @pytest .mark .parametrize (
389+ "filename_is_none, write_hmxfms, motion_affines_present, datahdr_present" ,
390+ list (itertools .product ([True , False ], repeat = 4 )),
391+ )
392+ def test_to_nifti (
393+ tmp_path , monkeypatch , filename_is_none , write_hmxfms , motion_affines_present , datahdr_present
394+ ):
395+ # Monkeypatch the helpers in the module where to_nifti is defined
396+ import nifreeze .data .base as base_mod
397+
398+ monkeypatch .setattr (base_mod , "LinearTransformsMapping" , DummyLinearTransformsMapping )
399+ monkeypatch .setattr (base_mod , "apply" , dummy_apply )
400+ monkeypatch .setattr (base_mod , "ImageGrid" , DummyImageGrid )
401+
402+ n_frames = 3
403+ shape = (4 , 4 , 2 , n_frames )
404+ dtype = np .int16
405+
406+ datahdr = None
407+ if datahdr_present :
408+ hdr = nb .Nifti1Header ()
409+ hdr .set_data_dtype (dtype )
410+ datahdr = hdr
411+
412+ motion_affines = None
413+ if motion_affines_present :
414+ motion_affines = [np .eye (4 ) for _ in range (n_frames )]
415+
416+ dataset = DummyDataset (shape , datahdr = datahdr , motion_affines = motion_affines , dtype = dtype )
417+
418+ filename = None
419+ if not filename_is_none :
420+ filename = tmp_path / "data.nii.gz"
421+
422+ # Suppress warnings in this test
423+ with warnings .catch_warnings ():
424+ warnings .simplefilter ("ignore" )
425+ nii = to_nifti (dataset , filename = filename , write_hmxfms = write_hmxfms , order = 1 )
426+
427+ # Check returned data
428+ assert isinstance (nii , nb .Nifti1Image )
429+ assert nii .shape == dataset .dataobj .shape
430+
431+ expected = dataset .dataobj .copy ()
432+ if motion_affines_present :
433+ # If motion affines are present, fake_apply adds the frame index to each
434+ # frame
435+ for i in range (n_frames ):
436+ expected [..., i ] = expected [..., i ] + i
437+ assert np .array_equal (nii .dataobj , expected ), (
438+ "Resampled data should reflect fake_apply modifications"
439+ )
440+ else :
441+ # No resampling; data should be identical to original
442+ assert np .array_equal (nii .dataobj , dataset .dataobj )
443+
444+ # Header behavior
445+ if datahdr_present :
446+ assert datahdr is not None
447+ assert nii .header .get_data_dtype () == datahdr .get_data_dtype ()
448+ else :
449+ xyzt = nii .header .get_xyzt_units ()
450+ assert xyzt [0 ].lower () == "mm"
451+
452+ # If filename was provided, file should exist and equal to transformed data
453+ if filename is not None :
454+ assert filename .is_file ()
455+ nii_load = load_api (filename , nb .Nifti1Image )
456+ assert np .array_equal (nii_load .get_fdata (), expected )
457+ else :
458+ assert not any (tmp_path .iterdir ()), "Directory is not empty"
459+
460+ # When motion_affines present and write_hmxfms True and filename provided,
461+ # x5 files should be written
462+ if motion_affines_present and write_hmxfms and filename is not None :
463+ # The same file is written at every iteration, so earlier transforms are
464+ # overwritten and only the last transform remains on disk
465+ found_x5 = list (tmp_path .glob ("*.x5" ))
466+ assert len (found_x5 ) == 1
467+ x5_path = filename .with_suffix ("" ).with_suffix (".x5" )
468+ assert x5_path .is_file ()
469+ content = x5_path .read_text ()
470+ assert content == f"transform-{ n_frames - 1 } "
471+ else :
472+ found_x5 = list (tmp_path .glob ("*.x5" ))
473+ assert len (found_x5 ) == 0
474+
475+
261476def test_load_hdf5 (random_dataset : BaseDataset ):
262477 """Test the 'load' function with an HDF5 file."""
263478 with TemporaryDirectory () as tmpdir :
0 commit comments