3838from nitransforms .resampling import apply
3939from typing_extensions import Self
4040
41- from nifreeze .data .base import BaseDataset , _cmp , _data_repr
42- from nifreeze .utils .ndimage import load_api
41+ from nifreeze .data .base import BaseDataset , _cmp , _data_repr , _has_ndim
42+ from nifreeze .utils .ndimage import get_data , load_api
43+
44+ ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+ """PET initialization array attribute absence error message."""
46+
47+ ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
48+ """PET initialization array attribute object error message."""
49+
50+ ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
51+ """PET initialization array attribute ndim error message."""
52+
53+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
54+ PET '{attribute}' length does not match number of frames: \
55+ expected {n_frames} values, found {attr_len}."""
56+ """PET attribute shape mismatch error message."""
57+
58+
59+ def validate_1d_array (inst : PET , attr : attrs .Attribute , value : Any ) -> None :
60+ """Strict validator to ensure an attribute is a 1D NumPy array.
61+
62+ Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly
63+ one dimension (``value.ndim == 1``).
64+
65+ This function is intended for use as an attrs-style validator.
66+
67+ Parameters
68+ ----------
69+ inst : :obj:`~nifreeze.data.pet.PET`
70+ The instance being validated (unused; present for validator signature).
71+ attr : :obj:`~attrs.Attribute`
72+ The attribute being validated; ``attr.name`` is used in the error message.
73+ value : :obj:`Any`
74+ The value to validate.
75+
76+ Raises
77+ ------
78+ exc:`TypeError`
79+ If the input cannot be converted to a float :obj:`~numpy.ndarray`.
80+ exc:`ValueError`
81+ If the value is ``None``, or not 1D.
82+ """
83+
84+ if value is None :
85+ raise ValueError (ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG .format (attribute = attr .name ))
86+
87+ if not isinstance (value , np .ndarray ):
88+ raise TypeError (ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG .format (attribute = attr .name ))
89+
90+ if not _has_ndim (value , 1 ):
91+ raise ValueError (ARRAY_ATTRIBUTE_NDIM_ERROR_MSG .format (attribute = attr .name ))
4392
4493
4594@attrs .define (slots = True )
4695class PET (BaseDataset [np .ndarray ]):
47- """Data representation structure for PET data."""
96+ """Data representation structure for PET data.
97+
98+ If not provided, frame duration data are computed as differences between
99+ consecutive midframe times. The last interval is duplicated. See
100+ :meth:`~nifreeze.data.pet._compute_frame_duration`.
101+ """
48102
49- midframe : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
103+ frame_time : np .ndarray = attrs .field (
104+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ), validator = validate_1d_array
105+ )
106+ """A (N,) numpy array specifying the timing of each sample or frame."""
107+ uptake : np .ndarray = attrs .field (
108+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ), validator = validate_1d_array
109+ )
110+ """A (N,) numpy array specifying the uptake value of each sample or frame."""
111+ frame_duration : np .ndarray | None = attrs .field (
112+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
113+ )
114+ """A (N,) numpy array specifying the frame duration."""
115+ midframe : np .ndarray = attrs .field (
116+ default = None , repr = _data_repr , init = False , eq = attrs .cmp_using (eq = _cmp )
117+ )
50118 """A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51- total_duration : float = attrs .field (default = None , repr = True )
119+ total_duration : float = attrs .field (default = None , repr = True , init = False )
52120 """A float representing the total duration of the dataset."""
53- uptake : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
54- """A (N,) numpy array specifying the uptake value of each sample or frame."""
121+
122+ def __attrs_post_init__ (self ) -> None :
123+ """Enforce presence and basic consistency of PET data fields at
124+ instantiation time.
125+
126+ Specifically, the length of the frame_time and uptake attributes must
127+ match the last dimension of the data (number of frames).
128+
129+ Computes the values for the private attributes.
130+ """
131+ n_frames = int (self .dataobj .shape [- 1 ])
132+
133+ if len (self .frame_time ) != n_frames :
134+ raise ValueError (
135+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
136+ attribute = attrs .fields_dict (self .__class__ )["frame_time" ].name ,
137+ n_frames = n_frames ,
138+ attr_len = len (self .frame_time ),
139+ )
140+ )
141+
142+ if len (self .uptake ) != n_frames :
143+ raise ValueError (
144+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
145+ attribute = attrs .fields_dict (self .__class__ )["uptake" ].name ,
146+ n_frames = n_frames ,
147+ attr_len = len (self .uptake ),
148+ )
149+ )
150+
151+ # Compute temporal attributes
152+
153+ # Convert to a float32 numpy array and zero out the earliest time
154+ frame_time_arr = np .array (self .frame_time , dtype = np .float32 )
155+ frame_time_arr -= frame_time_arr [0 ]
156+
157+ # If the user did not provide frame duration values, compute them
158+ if self .frame_duration :
159+ durations = np .array (self .frame_duration , dtype = np .float32 )
160+ else :
161+ durations = _compute_frame_duration (frame_time_arr )
162+
163+ # Compute total duration and shift midframe to the midpoint
164+ self .total_duration = float (frame_time_arr [- 1 ] + durations [- 1 ])
165+ self .midframe = frame_time_arr + 0.5 * durations
55166
56167 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
57168 return (self .midframe [idx ],)
@@ -222,6 +333,8 @@ def from_nii(
222333 frame_time : np .ndarray | list [float ],
223334 brainmask_file : Path | str | None = None ,
224335 frame_duration : np .ndarray | list [float ] | None = None ,
336+ uptake : np .ndarray | list [float ] | None = None ,
337+ uptake_stat_func : Callable [..., np .ndarray ] = np .sum ,
225338) -> PET :
226339 """
227340 Load PET data from NIfTI, creating a PET object with appropriate metadata.
@@ -236,9 +349,12 @@ def from_nii(
236349 A brainmask NIfTI file. If provided, will be loaded and
237350 stored in the returned dataset.
238351 frame_duration : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
239- The duration of each frame.
240- If ``None``, it is derived by the difference of consecutive frame times,
241- defaulting the last frame to match the second-last.
352+ The duration of each frame. If ``None``, its computation is deferred to
353+ the :obj:`~nifreeze.data.pet.PET` object initialization.
354+ uptake : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
355+ Uptake values. If provided, it ``uptake_stat_func`` will be ignored.
356+ uptake_stat_func : :obj:`Callable`, optional
357+ The statistic function to compute the uptake value.
242358
243359 Returns
244360 -------
@@ -253,37 +369,32 @@ def from_nii(
253369 """
254370
255371 filename = Path (filename )
256- # Load from NIfTI
257- img = load_api (filename , SpatialImage )
258- data = img .get_fdata (dtype = np .float32 )
259- pet_obj = PET (
260- dataobj = data ,
261- affine = img .affine ,
262- )
263-
264- pet_obj .uptake = _compute_uptake_statistic (data , stat_func = np .sum )
265372
266- # Convert to a float32 numpy array and zero out the earliest time
267- frame_time_arr = np .array (frame_time , dtype = np .float32 )
268- frame_time_arr -= frame_time_arr [0 ]
269- pet_obj .midframe = frame_time_arr
373+ # 1) Load a NIfTI
374+ img = load_api (filename , SpatialImage )
375+ fulldata = get_data (img )
270376
271- # If the user doesn't provide frame_duration, we derive it:
272- if frame_duration is None :
273- durations = _compute_frame_duration ( pet_obj . midframe )
377+ # 2) Determine uptake value
378+ if uptake is not None :
379+ uptake = np . asarray ( uptake )
274380 else :
275- durations = np .array (frame_duration , dtype = np .float32 )
276-
277- # Set total_duration and shift frame_time to the midpoint
278- pet_obj .total_duration = float (frame_time_arr [- 1 ] + durations [- 1 ])
279- pet_obj .midframe = frame_time_arr + 0.5 * durations
381+ uptake = _compute_uptake_statistic (fulldata , stat_func = uptake_stat_func )
280382
281- # If a brain mask is provided, load and attach
383+ # 3) If a brainmask_file was provided, load it
384+ brainmask_data = None
282385 if brainmask_file is not None :
283386 mask_img = load_api (brainmask_file , SpatialImage )
284- pet_obj . brainmask = np .asanyarray (mask_img .dataobj , dtype = bool )
387+ brainmask_data = np .asanyarray (mask_img .dataobj , dtype = bool )
285388
286- return pet_obj
389+ # 4) Create and return the PET instance
390+ return PET (
391+ dataobj = fulldata ,
392+ affine = img .affine ,
393+ brainmask = brainmask_data ,
394+ frame_time = np .asarray (frame_time ),
395+ frame_duration = frame_duration if frame_duration is None else np .asarray (frame_duration ),
396+ uptake = uptake ,
397+ )
287398
288399
289400def _compute_frame_duration (midframe : np .ndarray ) -> np .ndarray :
@@ -308,7 +419,7 @@ def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
308419 return durations
309420
310421
311- def _compute_uptake_statistic (data : np .ndarray , stat_func : Callable = np .sum ):
422+ def _compute_uptake_statistic (data : np .ndarray , stat_func : Callable [..., np . ndarray ] = np .sum ):
312423 """Compute a statistic over all voxels for each frame on a PET sequence.
313424
314425 Assumes the last dimension corresponds to the number of frames in the
0 commit comments