3838from nitransforms .resampling import apply
3939from typing_extensions import Self
4040
41- from nifreeze .data .base import BaseDataset , _cmp , _data_repr
42- from nifreeze .utils .ndimage import load_api
41+ from nifreeze .data .base import BaseDataset , _cmp , _data_repr , _has_ndim
42+ from nifreeze .utils .ndimage import get_data , load_api
43+
44+ ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+ """PET initialization array attribute absence error message."""
46+
47+ ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = (
48+ "PET '{attribute}' must be a numeric homogeneous array-like object."
49+ )
50+ """PET initialization array attribute object error message."""
51+
52+ ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
53+ """PET initialization array attribute ndim error message."""
54+
55+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
56+ PET '{attribute}' length does not match number of frames: \
57+ expected {n_frames} values, found {attr_len}."""
58+ """PET attribute shape mismatch error message."""
59+
60+
61+ def format_array_like (value : Any , attr : attrs .Attribute ) -> np .ndarray :
62+ """Validates that ``value`` can be converted to a :obj:`~numpy.ndarray`
63+
64+ This function is intended for use as an attrs-style formatter.
65+
66+ Parameters
67+ ----------
68+ value : :obj:`Any`
69+ The value to format.
70+ attr : :obj:`~attrs.Attribute`
71+ The attribute being initialized; ``attr.name`` is used in the error message.
72+
73+ Returns
74+ -------
75+ formatted : :obj:`~numpy.ndarray`
76+ The formatted value.
77+
78+ Raises
79+ ------
80+ exc:`TypeError`
81+ If the input cannot be converted to a float :obj:`~numpy.ndarray`.
82+ exc:`ValueError`
83+ If the value is ``None``.
84+ """
85+
86+ if value is None :
87+ raise ValueError (ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG .format (attribute = attr .name ))
88+
89+ try :
90+ formatted = np .asarray (value , dtype = float )
91+ except (TypeError , ValueError ) as exc :
92+ # Conversion failed (e.g. nested ragged objects, non-numeric)
93+ raise TypeError (ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG .format (attribute = attr .name )) from exc
94+
95+ return formatted
96+
97+
98+ def validate_1d_array (inst : PET , attr : attrs .Attribute , value : Any ) -> None :
99+ """Strict validator to ensure an attribute is a 1D NumPy array.
100+
101+ Enforces that ``value`` has exactly one dimension (``value.ndim == 1``).
102+
103+ This function is intended for use as an attrs-style validator.
104+
105+ Parameters
106+ ----------
107+ inst : :obj:`~nifreeze.data.pet.PET`
108+ The instance being validated (unused; present for validator signature).
109+ attr : :obj:`~attrs.Attribute`
110+ The attribute being validated; ``attr.name`` is used in the error message.
111+ value : :obj:`Any`
112+ The value to validate.
113+
114+ Raises
115+ ------
116+ exc:`ValueError`
117+ If the value is not 1D.
118+ """
119+
120+ if not _has_ndim (value , 1 ):
121+ raise ValueError (ARRAY_ATTRIBUTE_NDIM_ERROR_MSG .format (attribute = attr .name ))
43122
44123
45124@attrs .define (slots = True )
46125class PET (BaseDataset [np .ndarray ]):
47- """Data representation structure for PET data."""
126+ """Data representation structure for PET data.
127+
128+ Relevant temporal attributes, namely the per-frame duration and midframe
129+ times and the total duration, are computed at initialization.
130+
131+ Let :math:`K` be the number of frames. For each frame :math:`k`, we define
132+ the frame duration :math:`d_k`as the difference between consecutive midframe
133+ times:
134+
135+ .. math::
136+ d_k = t^{\\ mathrm{end}}_k - t^{\\ mathrm{start}}_k
137+
138+ In this implementation, the last interval is duplicated to match the
139+ appropriate dimensionality.
140+
141+ The total duration :math:`D` of the acquisition is a scalar computed as the
142+ sum of the frame durations:
143+
144+ .. math::
145+ D = \\ sum_{k=1}^{K} d_k
146+ = \\ sum_{k=1}^{K} \\ left(t^{\\ mathrm{end}}_k - t^{\\ mathrm{start}}_k\\ right)
147+
148+ One commonly useful scalar time is the overall mid-frame time (the midpoint
149+ of the whole acquisition). If :math:`t_{\\ mathrm{first}}` denotes the start
150+ time of the first frame, the overall midframe (the midpoint of the whole
151+ acquisition) is
48152
49- midframe : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
153+ .. math::
154+ \\ mathrm{midframe} = t_{\\ mathrm{first}} + \\ frac{D}{2}
155+
156+ Per-frame midpoints :math:`m_k` are exposed for convenience:
157+
158+ .. math::
159+ m_k = t^{\\ mathrm{start}}_k + \\ frac{d_k}{2}
160+
161+ Users can provide their own frame duration data to be used instead of the
162+ default values computed as shown above. See :meth:`~nifreeze.data.pet._compute_frame_duration`.
163+
164+ """
165+
166+ frame_time : np .ndarray = attrs .field (
167+ default = None ,
168+ repr = _data_repr ,
169+ eq = attrs .cmp_using (eq = _cmp ),
170+ converter = attrs .Converter (format_array_like , takes_field = True ),
171+ validator = validate_1d_array ,
172+ )
173+ """A (N,) numpy array specifying the timing of each sample or frame."""
174+ uptake : np .ndarray = attrs .field (
175+ default = None ,
176+ repr = _data_repr ,
177+ eq = attrs .cmp_using (eq = _cmp ),
178+ converter = attrs .Converter (format_array_like , takes_field = True ),
179+ validator = validate_1d_array ,
180+ )
181+ """A (N,) numpy array specifying the uptake value of each sample or frame."""
182+ frame_duration : np .ndarray | None = attrs .field (
183+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
184+ )
185+ """A (N,) numpy array specifying the frame duration."""
186+ midframe : np .ndarray = attrs .field (
187+ default = None , repr = _data_repr , init = False , eq = attrs .cmp_using (eq = _cmp )
188+ )
50189 """A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51- total_duration : float = attrs .field (default = None , repr = True )
190+ total_duration : float = attrs .field (default = None , repr = True , init = False )
52191 """A float representing the total duration of the dataset."""
53- uptake : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
54- """A (N,) numpy array specifying the uptake value of each sample or frame."""
192+
193+ def __attrs_post_init__ (self ) -> None :
194+ """Enforce presence and basic consistency of PET data fields at
195+ instantiation time.
196+
197+ Specifically, the length of the frame_time and uptake attributes must
198+ match the last dimension of the data (number of frames).
199+
200+ Computes the values for the private attributes.
201+ """
202+ n_frames = int (self .dataobj .shape [- 1 ])
203+
204+ if len (self .frame_time ) != n_frames :
205+ raise ValueError (
206+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
207+ attribute = attrs .fields_dict (self .__class__ )["frame_time" ].name ,
208+ n_frames = n_frames ,
209+ attr_len = len (self .frame_time ),
210+ )
211+ )
212+
213+ if len (self .uptake ) != n_frames :
214+ raise ValueError (
215+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
216+ attribute = attrs .fields_dict (self .__class__ )["uptake" ].name ,
217+ n_frames = n_frames ,
218+ attr_len = len (self .uptake ),
219+ )
220+ )
221+
222+ # Compute temporal attributes
223+
224+ # Convert to a float32 numpy array and zero out the earliest time
225+ frame_time_arr = np .array (self .frame_time , dtype = np .float32 )
226+ frame_time_arr -= frame_time_arr [0 ]
227+
228+ # If the user did not provide frame duration values, compute them
229+ if self .frame_duration is not None :
230+ self .frame_duration = np .array (self .frame_duration , dtype = np .float32 )
231+ else :
232+ self .frame_duration = _compute_frame_duration (frame_time_arr )
233+
234+ # Compute total duration and shift midframe to the midpoint
235+ self .total_duration = float (frame_time_arr [- 1 ] + self .frame_duration [- 1 ])
236+ self .midframe = frame_time_arr + 0.5 * self .frame_duration
55237
56238 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
57239 return (self .midframe [idx ],)
@@ -222,6 +404,8 @@ def from_nii(
222404 frame_time : np .ndarray | list [float ],
223405 brainmask_file : Path | str | None = None ,
224406 frame_duration : np .ndarray | list [float ] | None = None ,
407+ uptake : np .ndarray | list [float ] | None = None ,
408+ uptake_stat_func : Callable [..., np .ndarray ] = np .sum ,
225409) -> PET :
226410 """
227411 Load PET data from NIfTI, creating a PET object with appropriate metadata.
@@ -236,9 +420,12 @@ def from_nii(
236420 A brainmask NIfTI file. If provided, will be loaded and
237421 stored in the returned dataset.
238422 frame_duration : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
239- The duration of each frame.
240- If ``None``, it is derived by the difference of consecutive frame times,
241- defaulting the last frame to match the second-last.
423+ The duration of each frame. If ``None``, its computation is deferred to
424+ the :obj:`~nifreeze.data.pet.PET` object initialization.
425+ uptake : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
426+ Uptake values. If provided, it ``uptake_stat_func`` will be ignored.
427+ uptake_stat_func : :obj:`Callable`, optional
428+ The statistic function to compute the uptake value.
242429
243430 Returns
244431 -------
@@ -253,37 +440,34 @@ def from_nii(
253440 """
254441
255442 filename = Path (filename )
256- # Load from NIfTI
257- img = load_api (filename , SpatialImage )
258- data = img .get_fdata (dtype = np .float32 )
259- pet_obj = PET (
260- dataobj = data ,
261- affine = img .affine ,
262- )
263-
264- pet_obj .uptake = _compute_uptake_statistic (data , stat_func = np .sum )
265443
266- # Convert to a float32 numpy array and zero out the earliest time
267- frame_time_arr = np .array (frame_time , dtype = np .float32 )
268- frame_time_arr -= frame_time_arr [0 ]
269- pet_obj .midframe = frame_time_arr
444+ # 1) Load a NIfTI
445+ img = load_api (filename , SpatialImage )
446+ fulldata = get_data (img )
270447
271- # If the user doesn't provide frame_duration, we derive it:
272- if frame_duration is None :
273- durations = _compute_frame_duration ( pet_obj . midframe )
448+ # 2) Determine uptake value
449+ if uptake is not None :
450+ pass
274451 else :
275- durations = np . array ( frame_duration , dtype = np . float32 )
452+ uptake = _compute_uptake_statistic ( fulldata , stat_func = uptake_stat_func )
276453
277- # Set total_duration and shift frame_time to the midpoint
278- pet_obj .total_duration = float (frame_time_arr [- 1 ] + durations [- 1 ])
279- pet_obj .midframe = frame_time_arr + 0.5 * durations
454+ uptake = np .asarray (uptake )
280455
281- # If a brain mask is provided, load and attach
456+ # 3) If a brainmask_file was provided, load it
457+ brainmask_data = None
282458 if brainmask_file is not None :
283459 mask_img = load_api (brainmask_file , SpatialImage )
284- pet_obj . brainmask = np .asanyarray (mask_img .dataobj , dtype = bool )
460+ brainmask_data = np .asanyarray (mask_img .dataobj , dtype = bool )
285461
286- return pet_obj
462+ # 4) Create and return the PET instance
463+ return PET (
464+ dataobj = fulldata ,
465+ affine = img .affine ,
466+ brainmask = brainmask_data ,
467+ frame_time = np .asarray (frame_time ),
468+ frame_duration = frame_duration if frame_duration is None else np .asarray (frame_duration ),
469+ uptake = uptake ,
470+ )
287471
288472
289473def _compute_frame_duration (midframe : np .ndarray ) -> np .ndarray :
@@ -308,7 +492,7 @@ def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
308492 return durations
309493
310494
311- def _compute_uptake_statistic (data : np .ndarray , stat_func : Callable = np .sum ):
495+ def _compute_uptake_statistic (data : np .ndarray , stat_func : Callable [..., np . ndarray ] = np .sum ):
312496 """Compute a statistic over all voxels for each frame on a PET sequence.
313497
314498 Assumes the last dimension corresponds to the number of frames in the
0 commit comments