Skip to content

Commit 37ed54c

Browse files
committed
ENH: Validate PET data objects' attributes at instantiation
Validate PET data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities. Refactor the PET attributes so that only the required (`frame_time` and `uptake`) and optional (`frame_duration`) parameters are accepted by the constructor. The `midframe` and the `total_duration` attributes can be computed from the required parameters, so exclude them from `__init__`. Although `uptake` can also be computed from the PET frame data, the rationale behind requiring it is similar to the one for the DWI class `bzero`: users will be able to compute the `uptake` using their preferred strategy and provide it to the constructor. For the `from_nii` function, if a callable is provided, it will be used to compute the value; otherwise a default strategy is used to compute it. Refactor the `from_nii` function so that the required parameters are present when instantiating the PET instance. Increase consistency with the `dmri` data module `from_nii` counterpart function. Use the `get_data` utils function in `from_nii` to automatically handle the data type when loading the PET data. Refactor the PET data creation fixture in `conftest.py` to accept the required/optional arguments and to return the necessary data. Refactor the tests accordingly and increase consistency with the `dmri` data module testing helper functions. This reduces cognitive load and maintenance burden. Add additional object instantiation equality checks: check that objects instantiated through reading NIfTI files equal objects instantiated directly.
1 parent 62e5e43 commit 37ed54c

File tree

3 files changed

+526
-55
lines changed

3 files changed

+526
-55
lines changed

src/nifreeze/data/pet.py

Lines changed: 218 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,202 @@
3838
from nitransforms.resampling import apply
3939
from typing_extensions import Self
4040

41-
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
42-
from nifreeze.utils.ndimage import load_api
41+
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim
42+
from nifreeze.utils.ndimage import get_data, load_api
43+
44+
# Error message templates used when validating PET attributes. Each template
# is filled in with ``str.format``; ``{attribute}`` is the name of the
# offending attribute.

ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
"""PET initialization array attribute absence error message."""

ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = (
    "PET '{attribute}' must be a numeric homogeneous array-like object."
)
"""PET initialization array attribute object error message."""

ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
"""PET initialization array attribute ndim error message."""

# Additionally takes ``{n_frames}`` (expected length) and ``{attr_len}``
# (actual length).
ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = (
    "PET '{attribute}' length does not match number of frames: "
    "expected {n_frames} values, found {attr_len}."
)
"""PET attribute shape mismatch error message."""
59+
60+
61+
def format_array_like(value: Any, attr: attrs.Attribute) -> np.ndarray:
    """Convert ``value`` to a float :obj:`~numpy.ndarray`, rejecting ``None``.

    This function is intended for use as an attrs-style converter that also
    receives the field being initialized.

    Parameters
    ----------
    value : :obj:`Any`
        The value to format.
    attr : :obj:`~attrs.Attribute`
        The attribute being initialized; ``attr.name`` is used in the error
        messages.

    Returns
    -------
    formatted : :obj:`~numpy.ndarray`
        The result of ``np.asarray(value, dtype=float)``.

    Raises
    ------
    :exc:`ValueError`
        If the value is ``None``.
    :exc:`TypeError`
        If the input cannot be converted to a float :obj:`~numpy.ndarray`.
    """
    if value is None:
        raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))

    try:
        return np.asarray(value, dtype=float)
    except (TypeError, ValueError) as exc:
        # Conversion failed (e.g. nested ragged objects, non-numeric). Both
        # failure modes are reported uniformly as a TypeError that names the
        # attribute, with the original exception chained as the cause.
        raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) from exc
96+
97+
98+
def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None:
    """Strict validator to ensure an attribute is a 1D NumPy array.

    The check is delegated to ``_has_ndim(value, 1)``; a falsy result is
    reported as an error.

    This function is intended for use as an attrs-style validator.

    Parameters
    ----------
    inst : :obj:`~nifreeze.data.pet.PET`
        The instance being validated (unused; present for validator signature).
    attr : :obj:`~attrs.Attribute`
        The attribute being validated; ``attr.name`` is used in the error message.
    value : :obj:`Any`
        The value to validate.

    Raises
    ------
    :exc:`ValueError`
        If the value is not 1D.
    """
    if _has_ndim(value, 1):
        return

    raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name))
43122

44123

45124
@attrs.define(slots=True)
46125
class PET(BaseDataset[np.ndarray]):
47-
"""Data representation structure for PET data."""
126+
"""Data representation structure for PET data.
127+
128+
Relevant temporal attributes, namely the per-frame duration and midframe
129+
times and the total duration, are computed at initialization.
130+
131+
Let :math:`K` be the number of frames. For each frame :math:`k`, we define
132+
the frame duration :math:`d_k` as the difference between consecutive midframe
133+
times:
134+
135+
.. math::
136+
d_k = t^{\\mathrm{end}}_k - t^{\\mathrm{start}}_k
137+
138+
In this implementation, the last interval is duplicated to match the
139+
appropriate dimensionality.
140+
141+
The total duration :math:`D` of the acquisition is a scalar computed as the
142+
sum of the frame durations:
143+
144+
.. math::
145+
D = \\sum_{k=1}^{K} d_k
146+
= \\sum_{k=1}^{K} \\left(t^{\\mathrm{end}}_k - t^{\\mathrm{start}}_k\\right)
147+
148+
One commonly useful scalar time is the overall mid-frame time (the midpoint
149+
of the whole acquisition). If :math:`t_{\\mathrm{first}}` denotes the start
150+
time of the first frame, the overall midframe (the midpoint of the whole
151+
acquisition) is
48152
49-
midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
153+
.. math::
154+
\\mathrm{midframe} = t_{\\mathrm{first}} + \\frac{D}{2}
155+
156+
Per-frame midpoints :math:`m_k` are exposed for convenience:
157+
158+
.. math::
159+
m_k = t^{\\mathrm{start}}_k + \\frac{d_k}{2}
160+
161+
Users can provide their own frame duration data to be used instead of the
162+
default values computed as shown above. See :func:`~nifreeze.data.pet._compute_frame_duration`.
163+
164+
"""
165+
166+
frame_time: np.ndarray = attrs.field(
167+
default=None,
168+
repr=_data_repr,
169+
eq=attrs.cmp_using(eq=_cmp),
170+
converter=attrs.Converter(format_array_like, takes_field=True),
171+
validator=validate_1d_array,
172+
)
173+
"""A (N,) numpy array specifying the timing of each sample or frame."""
174+
uptake: np.ndarray = attrs.field(
175+
default=None,
176+
repr=_data_repr,
177+
eq=attrs.cmp_using(eq=_cmp),
178+
converter=attrs.Converter(format_array_like, takes_field=True),
179+
validator=validate_1d_array,
180+
)
181+
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
182+
frame_duration: np.ndarray | None = attrs.field(
183+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
184+
)
185+
"""A (N,) numpy array specifying the frame duration."""
186+
midframe: np.ndarray = attrs.field(
187+
default=None, repr=_data_repr, init=False, eq=attrs.cmp_using(eq=_cmp)
188+
)
50189
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51-
total_duration: float = attrs.field(default=None, repr=True)
190+
total_duration: float = attrs.field(default=None, repr=True, init=False)
52191
"""A float representing the total duration of the dataset."""
53-
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
54-
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
192+
193+
def __attrs_post_init__(self) -> None:
    """Check per-frame attributes against the data and derive timing fields.

    ``frame_time``, ``uptake`` and, when provided, ``frame_duration`` must
    each hold exactly one value per frame, i.e. their length must match the
    last dimension of ``dataobj``.

    After validation the remaining temporal attributes are populated:

    * ``frame_duration``: the user-provided values as a float32 array, or,
      when not provided, the result of ``_compute_frame_duration`` applied
      to the frame times shifted so that the first one is zero.
    * ``total_duration``: the last shifted frame time plus the last frame
      duration.
    * ``midframe``: the shifted frame times plus half the frame durations.

    Raises
    ------
    :exc:`ValueError`
        If the length of ``frame_time``, ``uptake`` or a provided
        ``frame_duration`` does not match the number of frames, or if a
        provided ``frame_duration`` is not 1D.
    """
    n_frames = int(self.dataobj.shape[-1])

    # Required per-frame attributes: one value per frame.
    for name in ("frame_time", "uptake"):
        attr_len = len(getattr(self, name))
        if attr_len != n_frames:
            raise ValueError(
                ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format(
                    attribute=name,
                    n_frames=n_frames,
                    attr_len=attr_len,
                )
            )

    # Convert to a float32 numpy array and zero out the earliest time
    frame_time_arr = np.array(self.frame_time, dtype=np.float32)
    frame_time_arr -= frame_time_arr[0]

    if self.frame_duration is None:
        # The user did not provide frame duration values: compute them
        self.frame_duration = _compute_frame_duration(frame_time_arr)
    else:
        durations = np.array(self.frame_duration, dtype=np.float32)
        # Without these checks a scalar would fail below with an IndexError
        # and a length-1 array would silently broadcast over all frames.
        if durations.ndim != 1:
            raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="frame_duration"))
        if len(durations) != n_frames:
            raise ValueError(
                ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format(
                    attribute="frame_duration",
                    n_frames=n_frames,
                    attr_len=len(durations),
                )
            )
        self.frame_duration = durations

    # Compute total duration and shift midframe to the midpoint
    self.total_duration = float(frame_time_arr[-1] + self.frame_duration[-1])
    self.midframe = frame_time_arr + 0.5 * self.frame_duration
55237

56238
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
    """Return the extra per-frame data selected by ``idx``.

    Parameters
    ----------
    idx : :obj:`int`, :obj:`slice`, :obj:`tuple` or :obj:`~numpy.ndarray`
        Index applied to the ``midframe`` attribute.

    Returns
    -------
    :obj:`tuple`
        A one-element tuple holding ``self.midframe[idx]``.
    """
    return (self.midframe[idx],)
@@ -222,6 +404,8 @@ def from_nii(
222404
frame_time: np.ndarray | list[float],
223405
brainmask_file: Path | str | None = None,
224406
frame_duration: np.ndarray | list[float] | None = None,
407+
uptake: np.ndarray | list[float] | None = None,
408+
uptake_stat_func: Callable[..., np.ndarray] = np.sum,
225409
) -> PET:
226410
"""
227411
Load PET data from NIfTI, creating a PET object with appropriate metadata.
@@ -236,9 +420,12 @@ def from_nii(
236420
A brainmask NIfTI file. If provided, will be loaded and
237421
stored in the returned dataset.
238422
frame_duration : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
239-
The duration of each frame.
240-
If ``None``, it is derived by the difference of consecutive frame times,
241-
defaulting the last frame to match the second-last.
423+
The duration of each frame. If ``None``, its computation is deferred to
424+
the :obj:`~nifreeze.data.pet.PET` object initialization.
425+
uptake : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
426+
Uptake values. If provided, ``uptake_stat_func`` will be ignored.
427+
uptake_stat_func : :obj:`Callable`, optional
428+
The statistic function to compute the uptake value.
242429
243430
Returns
244431
-------
@@ -253,37 +440,34 @@ def from_nii(
253440
"""
254441

255442
filename = Path(filename)
256-
# Load from NIfTI
257-
img = load_api(filename, SpatialImage)
258-
data = img.get_fdata(dtype=np.float32)
259-
pet_obj = PET(
260-
dataobj=data,
261-
affine=img.affine,
262-
)
263-
264-
pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum)
265443

266-
# Convert to a float32 numpy array and zero out the earliest time
267-
frame_time_arr = np.array(frame_time, dtype=np.float32)
268-
frame_time_arr -= frame_time_arr[0]
269-
pet_obj.midframe = frame_time_arr
444+
# 1) Load a NIfTI
445+
img = load_api(filename, SpatialImage)
446+
fulldata = get_data(img)
270447

271-
# If the user doesn't provide frame_duration, we derive it:
272-
if frame_duration is None:
273-
durations = _compute_frame_duration(pet_obj.midframe)
448+
# 2) Determine uptake value
449+
if uptake is not None:
450+
pass
274451
else:
275-
durations = np.array(frame_duration, dtype=np.float32)
452+
uptake = _compute_uptake_statistic(fulldata, stat_func=uptake_stat_func)
276453

277-
# Set total_duration and shift frame_time to the midpoint
278-
pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1])
279-
pet_obj.midframe = frame_time_arr + 0.5 * durations
454+
uptake = np.asarray(uptake)
280455

281-
# If a brain mask is provided, load and attach
456+
# 3) If a brainmask_file was provided, load it
457+
brainmask_data = None
282458
if brainmask_file is not None:
283459
mask_img = load_api(brainmask_file, SpatialImage)
284-
pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
460+
brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)
285461

286-
return pet_obj
462+
# 4) Create and return the PET instance
463+
return PET(
464+
dataobj=fulldata,
465+
affine=img.affine,
466+
brainmask=brainmask_data,
467+
frame_time=np.asarray(frame_time),
468+
frame_duration=frame_duration if frame_duration is None else np.asarray(frame_duration),
469+
uptake=uptake,
470+
)
287471

288472

289473
def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
@@ -308,7 +492,7 @@ def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
308492
return durations
309493

310494

311-
def _compute_uptake_statistic(data: np.ndarray, stat_func: Callable = np.sum):
495+
def _compute_uptake_statistic(data: np.ndarray, stat_func: Callable[..., np.ndarray] = np.sum):
312496
"""Compute a statistic over all voxels for each frame on a PET sequence.
313497
314498
Assumes the last dimension corresponds to the number of frames in the

test/conftest.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,11 @@ def setup_random_pet_data(request):
323323

324324
n_frames = 5
325325
vol_size = (4, 4, 4)
326-
midframe = np.arange(n_frames, dtype=np.float32) + 1
327-
total_duration = float(n_frames + 1)
326+
frame_time = np.arange(n_frames, dtype=np.float32) + 1
327+
uptake_stat_func = np.sum
328+
frame_duration = None
328329
if marker:
329-
n_frames, vol_size, midframe, total_duration = marker.args
330+
n_frames, vol_size, frame_time, uptake_stat_func, frame_duration = marker.args
330331

331332
rng = request.node.rng
332333

@@ -335,10 +336,13 @@ def setup_random_pet_data(request):
335336
)
336337
brainmask_dataobj = rng.choice([True, False], size=vol_size).astype(bool)
337338

339+
uptake = uptake_stat_func(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0)
340+
338341
return (
339342
pet_dataobj,
340343
affine,
341344
brainmask_dataobj,
342-
midframe,
343-
total_duration,
345+
frame_time,
346+
uptake,
347+
frame_duration,
344348
)

0 commit comments

Comments
 (0)