Skip to content

Commit a1cedb7

Browse files
committed
TST: Unit test data module load function
Unit test `data` module `load` function.
1 parent fc78dec commit a1cedb7

File tree

1 file changed

+272
-0
lines changed

1 file changed

+272
-0
lines changed

test/test_data.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
24+
import h5py
25+
import nibabel as nb
26+
import numpy as np
27+
import pytest
28+
29+
from nifreeze import data
30+
from nifreeze.data import dmri, pet
31+
32+
33+
def _raise_type(*args, **kwargs):
34+
raise TypeError("Cannot read")
35+
36+
37+
def test_load_hdf5_error(monkeypatch, tmp_path):
    """``load`` must raise ``TypeError`` when every dataclass reader fails."""
    fname = tmp_path / ("dataset" + data.NFDH5_EXT)

    # Make all three candidate readers raise TypeError so that load() is
    # left with no working fallback and must propagate the failure.
    for klass in (data.BaseDataset, data.PET, data.DWI):
        monkeypatch.setattr(
            klass,
            "from_filename",
            classmethod(lambda _cls, fn: _raise_type()),
            raising=False,
        )

    with pytest.raises(TypeError, match="Could not read data"):
        data.load(fname)
57+
58+
@pytest.mark.parametrize(
    "target, prior_failures",
    [
        (data.BaseDataset, []),
        (data.PET, [data.BaseDataset]),
        (data.DWI, [data.BaseDataset, data.PET]),
    ],
)
def test_load_hdf5_sentinel(monkeypatch, tmp_path, target, prior_failures):
    """``load`` returns whatever the first successful reader produces."""
    fname = tmp_path / ("dataset" + data.NFDH5_EXT)
    marker = object()

    # Readers tried before ``target`` fail with TypeError ...
    for klass in prior_failures:
        monkeypatch.setattr(
            klass,
            "from_filename",
            classmethod(lambda _cls, fn: _raise_type()),
            raising=False,
        )

    # ... while ``target`` succeeds and hands back our marker object.
    monkeypatch.setattr(
        target,
        "from_filename",
        classmethod(lambda _cls, fn: marker),
        raising=False,
    )

    assert data.load(fname) is marker
84+
85+
@pytest.mark.parametrize(
    "target, prior_failures, vol_size",
    [
        (data.BaseDataset, [], (4, 5, 6, 2)),
        (data.PET, [data.BaseDataset], (3, 4, 3, 5)),
        (data.DWI, [data.BaseDataset, data.PET], (2, 2, 6, 4)),
    ],
)
def test_load_hdf5_data(request, tmp_path, monkeypatch, target, prior_failures, vol_size):
    """
    For each target dataclass, write a tiny HDF5 file with random data, force
    earlier readers to raise TypeError, and monkeypatch the target's
    from_filename to read the HDF5 and return a small object containing the data
    so we can assert it was read.
    """

    rng = request.node.rng

    # Create random arrays to write into the HDF5 file
    dataobj = rng.random(vol_size).astype(np.float32)
    affine = np.eye(4).astype(np.float64)
    brainmask_dataobj = rng.choice([True, False], size=dataobj.shape[:3]).astype(np.uint8)

    fname = tmp_path / ("dataset" + data.NFDH5_EXT)

    # Write a minimal HDF5 layout that our patched reader will understand
    with h5py.File(fname, "w") as f:
        f.create_dataset("dataobj", data=dataobj)
        f.create_dataset("affine", data=affine)
        f.create_dataset("brainmask", data=brainmask_dataobj)

    # Force earlier readers to raise TypeError so load() will try the target
    for cls in prior_failures:
        monkeypatch.setattr(
            cls, "from_filename", classmethod(lambda _cls, fn: _raise_type()), raising=False
        )

    # Define a reader that reads our HDF5 layout and returns a simple object
    def _from_filename(filename):
        with h5py.File(filename, "r") as _f:
            _dataobj = np.array(_f["dataobj"])
            _affine = np.array(_f["affine"])
            _brainmask = np.array(_f["brainmask"]).astype(bool)

        class SimpleBaseDataset:
            """Minimal holder exposing the attributes consumers expect."""

            def __init__(self, **kwargs):
                self.dataobj = kwargs["dataobj"]
                self.affine = kwargs["affine"]
                self.brainmask = kwargs.get("brainmask")

        # BUG FIX: the original instantiated SimpleBaseDataset() with no
        # arguments, which raises KeyError("dataobj") inside __init__ before
        # the attribute assignments below it could ever run. Pass the data
        # through the constructor instead.
        return SimpleBaseDataset(dataobj=_dataobj, affine=_affine, brainmask=_brainmask)

    # Patch the target class's from_filename to use our reader
    monkeypatch.setattr(
        target,
        "from_filename",
        classmethod(lambda _cls, fn: _from_filename(fn)),
        raising=False,
    )

    # Call the top-level loader and verify we got back the object with the data
    retval = data.load(fname)

    # The returned object should have the attributes we set above
    assert hasattr(retval, "dataobj")
    assert hasattr(retval, "affine")
    assert hasattr(retval, "brainmask")

    assert retval.dataobj is not None
    assert retval.dataobj.shape == dataobj.shape
    assert np.allclose(retval.dataobj, dataobj)

    assert retval.affine is not None
    assert retval.affine.shape == affine.shape
    assert np.array_equal(retval.affine, affine)

    assert retval.brainmask is not None
    assert retval.brainmask.shape == brainmask_dataobj.shape
    assert np.array_equal(retval.brainmask, brainmask_dataobj)
170+
171+
@pytest.mark.random_uniform_spatial_data((5, 2, 4), 0.0, 1.0)
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
@pytest.mark.parametrize(
    "use_brainmask, kwargs",
    [
        (True, {}),
        (False, {"data": 2.0}),
    ],
)
def test_load_base_nifti(
    request, tmp_path, monkeypatch, setup_random_uniform_spatial_data, use_brainmask, kwargs
):
    """``load`` on NIfTI inputs populates dataobj, affine and brainmask."""
    rng = request.node.rng
    vol, vol_affine = setup_random_uniform_spatial_data

    # Serialize the image volume for load() to pick up
    img_fname = tmp_path / "data.nii.gz"
    nb.save(nb.Nifti1Image(vol, vol_affine), img_fname)

    # Either a random boolean mask or an all-ones fallback
    if use_brainmask:
        mask_arr = rng.choice([True, False], size=vol.shape[:3]).astype(bool)
    else:
        mask_arr = np.ones(vol.shape[:3]).astype(bool)

    brainmask_fname = tmp_path / "brainmask.nii.gz"
    nb.save(nb.Nifti1Image(mask_arr.astype(np.uint8), vol_affine), brainmask_fname)

    # Monkeypatch BaseDataset to a minimal holder class that mirrors the API
    class SimpleBaseDataset:
        def __init__(self, **kwargs):
            self.dataobj = kwargs["dataobj"]
            self.affine = kwargs["affine"]
            self.brainmask = None

    monkeypatch.setattr(data, "BaseDataset", SimpleBaseDataset)

    retval = data.load(img_fname, brainmask_file=brainmask_fname, **kwargs)

    assert isinstance(retval, data.BaseDataset)

    for attr in ("dataobj", "brainmask", "affine"):
        assert hasattr(retval, attr)

    assert retval.dataobj is not None
    assert np.allclose(retval.dataobj, vol)

    assert retval.affine is not None
    assert np.allclose(retval.affine, vol_affine)

    assert retval.brainmask is not None
    assert np.array_equal(retval.brainmask, mask_arr)
223+
224+
def test_load_dmri_from_nii(monkeypatch, tmp_path):
    """``load`` delegates diffusion NIfTI files to ``dmri.from_nii``."""
    fname = tmp_path / "image.nii"
    mask = tmp_path / "mask.nii"

    recorded = {}
    marker = object()

    # Record the delegated call's arguments and return a marker object
    def fake_from_nii(filename, brainmask_file=None, **kwargs):
        recorded.update(filename=filename, brainmask_file=brainmask_file, kwargs=kwargs)
        return marker

    monkeypatch.setattr(dmri, "from_nii", fake_from_nii)

    res = data.load(fname, brainmask_file=mask, gradients_file="grad.txt", bvec_file="bvecs.txt")

    assert res is marker
    assert recorded["filename"] == fname
    assert recorded["brainmask_file"] == mask
    assert "gradients_file" in recorded["kwargs"]
    # The b-vector keyword may be forwarded under any of these names
    assert any(key in recorded["kwargs"] for key in ("bvec_file", "bvecs_file", "bvecs"))
251+
252+
def test_load_pet_from_nii(monkeypatch, tmp_path):
    """``load`` delegates PET NIfTI files to ``pet.from_nii``."""
    fname = tmp_path / "image.nii"
    mask = tmp_path / "mask.nii"

    recorded = {}
    marker = object()

    # Record the delegated call's arguments and return a marker object
    def fake_from_nii(filename, brainmask_file=None, **kwargs):
        recorded.update(filename=filename, brainmask_file=brainmask_file, kwargs=kwargs)
        return marker

    monkeypatch.setattr(pet, "from_nii", fake_from_nii)

    retval = data.load(fname, brainmask_file=mask, temporal_file="temporal.txt")

    assert retval is marker
    assert recorded["filename"] == fname
    assert recorded["brainmask_file"] == mask
    assert "temporal_file" in recorded["kwargs"]

0 commit comments

Comments
 (0)