From 40b7add67bd13df61ab080cbad6ce8c313b52951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Wed, 3 Dec 2025 16:20:51 -0500 Subject: [PATCH 1/2] ENH: Validate PET data objects' attributes at instantiation Validate PET data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities. **PET class attributes** Refactor the PET attributes so that only `midframe` and `total_duration` are required and accepted by the constructor. These are the only parameters that are required by the current PET model. Remove `uptake` from the constructor: the PET data class does not need to know the uptake values held across its frames; it is rather the estimator that needs to know about its values so that the iterator can pick the frames following the appropriate sorting. Validate and format attributes so to avoid missing or inconsistent data. Specifically, require the midframe data to have the same length as the number of frames in the data object, and disallow the last midframe value being larget than the total duration. Make the `_compute_uptake_statistic` public so that users can call it. **`from_nii`** function: Refactor the `from_nii` function to accept filenames instead of a mix of filenames (e.g. the PET image sequence and brainmask) and temporal attribute arrays. Honors the name of the function, increases consistency with the dMRI counterpart and allows to offer a uniform API. The only required temporal parameter required by BIDS is the frame time (`FrameTimesStart`). Thus, the temporal attribute JSON (sidecar) file is required to contain that key. The values required to model a PET datast for the purposes of NiFreeze, namely the midframe and total duration values, are computed from the frame time. It is assumed that the frame duration spans entirely the time elapsed between two consecutire time frame values. Refactor and rename the `_compute_frame_duration` function so that it computes and returns the required parameters to instantiate a PET data object. The computation of the relevant temporal values is, thus, done at this place only. Use the `get_data` utils function in `from_nii` to handle automatically the data type when loading the PET data. **`PET.load`** class method: Remove the `PET.load` class method and rely on the `data.__init__.load` function: - If an HDF5 filename is provided, it is assumed that it hosts all necessary information, and the data module `load` function should take of loading all data. - If the provided arguments are NIfTI files plus other data files, the function will call the `pet.PET.from_nii` function. Change the `kwargs` arguments to be able to identify the relevant keyword arguments that are now present in the `from_nii` function. Change accordingly the `PET.load(pet_file, json_file)` call in the PET notebook and the `test_pet_load` test function. **Tests**: Refactor the PET data creation fixture in `conftest.py` to accept the `frame_time` (as it is the only required arguments by BIDS and the one that allows computing the rest) and to return the necessary data. Remove values that are no longer needed (i.e. `total_duration`). Refactor the tests accordingly and increase consistency with the `dmri` data module testing helper functions. Reduces cognitive load and maintenance burden. Add additional object instantiation equality checks: check that objects intantiated through reading NIfTI files equal objects instantiated directly. Check the PET dataset attributes systematically in round trip tests by collecting all named attributes that need to be tested. Modify accordingly the PET model and integration tests. Take advantage of the patch set to make other opinionated choices: - Prefer using the global `setup_random_pet_data` fixture over the local `random_dataset` fixture: it allows to control the parameters of the generated data and increases consistency with the practice adopted across the dMRI dataset tests. Remove the `random_dataset` fixture. - Prefer using `assert np.allclose` over `np.testing.assert_array_equal` for the sake of consistency **Dependencies** Require `attrs>24.1.0` so that `attrs.Converter` can be used. Documentation: https://www.attrs.org/en/25.4.0/api.html#converters --- docs/notebooks/pet_motion_estimation.ipynb | 4 +- pyproject.toml | 2 +- src/nifreeze/data/__init__.py | 2 +- src/nifreeze/data/pet.py | 393 ++++++++++--- test/conftest.py | 14 +- test/test_data.py | 10 +- test/test_data_pet.py | 644 +++++++++++++++++---- test/test_estimator.py | 30 +- test/test_integration_pet.py | 88 ++- test/test_model_pet.py | 65 ++- 10 files changed, 987 insertions(+), 265 deletions(-) diff --git a/docs/notebooks/pet_motion_estimation.ipynb b/docs/notebooks/pet_motion_estimation.ipynb index dbef481d3..ba528406d 100644 --- a/docs/notebooks/pet_motion_estimation.ipynb +++ b/docs/notebooks/pet_motion_estimation.ipynb @@ -10,7 +10,7 @@ "from os import getenv\n", "from pathlib import Path\n", "\n", - "from nifreeze.data.pet import PET\n", + "from nifreeze.data.pet import from_nii\n", "\n", "# Install test data from gin.g-node.org:\n", "# $ datalad install -g https://gin.g-node.org/nipreps-data/tests-nifreeze.git\n", @@ -29,7 +29,7 @@ " DATA_PATH / \"pet_data\" / \"sub-02\" / \"ses-baseline\" / \"pet\" / \"sub-02_ses-baseline_pet.json\"\n", ")\n", "\n", - "pet_dataset = PET.load(pet_file, json_file)" + "pet_dataset = from_nii(pet_file, temporal_file=json_file)" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 78d83e2d9..9fb042f68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ license = "Apache-2.0" requires-python = ">=3.10" dependencies = [ - "attrs>=20.1.0", + "attrs>=24.1.0", "dipy>=1.5.0", "joblib", "nipype>=1.5.1,<2.0", diff --git a/src/nifreeze/data/__init__.py b/src/nifreeze/data/__init__.py index f9d1b245e..8249dcf43 100644 --- a/src/nifreeze/data/__init__.py +++ b/src/nifreeze/data/__init__.py @@ -76,7 +76,7 @@ def load( from nifreeze.data.dmri import from_nii as dmri_from_nii return dmri_from_nii(filename, brainmask_file=brainmask_file, **kwargs) - elif {"frame_time", "frame_duration"} & set(kwargs): + elif {"temporal_file"} & set(kwargs): from nifreeze.data.pet import from_nii as pet_from_nii return pet_from_nii(filename, brainmask_file=brainmask_file, **kwargs) diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py index 7e06fd913..f532f9923 100644 --- a/src/nifreeze/data/pet.py +++ b/src/nifreeze/data/pet.py @@ -28,7 +28,7 @@ from collections import namedtuple from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import Any, Tuple import attrs import h5py @@ -39,20 +39,227 @@ from nitransforms.resampling import apply from typing_extensions import Self -from nifreeze.data.base import BaseDataset, _cmp, _data_repr -from nifreeze.utils.ndimage import load_api +from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim +from nifreeze.utils.ndimage import get_data, load_api + +ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None" +"""PET initialization array attribute absence error message.""" + +ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = ( + "PET '{attribute}' must be a numeric homogeneous array-like object." +) +"""PET initialization array attribute object error message.""" + +ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array." +"""PET initialization array attribute ndim error message.""" + +ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG = """\ +PET '{attribute}' length does not match number of frames: \ +expected {n_frames} values, found {attr_len}.""" +"""PET attribute shape mismatch error message.""" + +TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG = """\ +PET 'total_duration' cannot be smaller than last 'midframe' value: \ +found {total_duration} and {last_midframe}.""" +"""PET attribute inconsistency error message.""" + +SCALAR_ATTRIBUTE_ERROR_MSG = ( + "PET '{attribute}' must be a numeric or single-element sequence object." +) +"""PET initialization scalar attribute object error message.""" + +TEMPORAL_FILE_KEY_ERROR_MSG = "{key} key not found in temporal file" +"""PET temporal file key error message.""" + +FRAME_TIME_START_KEY = "FrameTimesStart" +"""PET frame time start key.""" + + +def format_scalar_like(value: Any, attr: attrs.Attribute) -> float: + """Convert ``value`` to a scalar. + + Accepts: + - :obj:`float` or :obj:`int` (but rejects :obj:`bool`) + - Numpy scalar (:obj:`~numpy.generic`, e.g. :obj:`~numpy.floating`, :obj:`~numpy.integer`) + - :obj:`~numpy.ndarray` of size 1 + - :obj:`list`/:obj:`tuple` of length 1 + + This function is intended for use as an attrs-style formatter. + + Parameters + ---------- + value : :obj:`Any` + The value to format. + attr : :obj:`~attrs.Attribute` + The attribute being initialized; ``attr.name`` is used in the error message. + + Returns + ------- + formatted : :obj:`float` + The formatted value. + + Raises + ------ + exc:`TypeError` + If the input cannot be converted to a scalar. + exc:`ValueError` + If the value is ``None``, is of type :obj:`bool` or has not size/length 1. + """ + + if value is None: + raise ValueError(ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name)) + + # Reject bool explicitly (bool is subclass of int) + if isinstance(value, bool): + raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name)) + + # Numpy scalar (np.generic) or numpy 0-d array + if np is not None and isinstance(value, np.generic): + return float(value.item()) + + # Numpy ndarray (ndarray) + if np is not None and isinstance(value, np.ndarray): + if value.size != 1: + raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name)) + return float(value.ravel()[0]) + + # List/tuple with single element + if isinstance(value, (list, tuple)): + if len(value) != 1: + raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name)) + return float(value[0]) + + # Plain int/float (but not bool) + if isinstance(value, (int, float)): + return float(value) + + # Fallback: try to use .item() if present + item = getattr(value, "item", None) + if callable(item): + try: + return float(item()) + except Exception: + pass + + raise TypeError(f"Cannot convert {type(value)!r} to float") + + +def format_array_like(value: Any, attr: attrs.Attribute) -> np.ndarray: + """Convert ``value`` to a :obj:`~numpy.ndarray`. + + This function is intended for use as an attrs-style formatter. + + Parameters + ---------- + value : :obj:`Any` + The value to format. + attr : :obj:`~attrs.Attribute` + The attribute being initialized; ``attr.name`` is used in the error message. + + Returns + ------- + formatted : :obj:`~numpy.ndarray` + The formatted value. + + Raises + ------ + exc:`TypeError` + If the input cannot be converted to a float :obj:`~numpy.ndarray`. + exc:`ValueError` + If the value is ``None``. + """ + + if value is None: + raise ValueError(ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name)) + + try: + formatted = np.asarray(value, dtype=float) + except (TypeError, ValueError) as exc: + # Conversion failed (e.g. nested ragged objects, non-numeric) + raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) from exc + + return formatted + + +def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None: + """Strict validator to ensure an attribute is a 1D NumPy array. + + Enforces that ``value`` has exactly one dimension (``value.ndim == 1``). + + This function is intended for use as an attrs-style validator. + + Parameters + ---------- + inst : :obj:`~nifreeze.data.pet.PET` + The instance being validated (unused; present for validator signature). + attr : :obj:`~attrs.Attribute` + The attribute being validated; ``attr.name`` is used in the error message. + value : :obj:`Any` + The value to validate. + + Raises + ------ + exc:`ValueError` + If the value is not 1D. + """ + + if not _has_ndim(value, 1): + raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name)) @attrs.define(slots=True) class PET(BaseDataset[np.ndarray]): """Data representation structure for PET data.""" - midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) + midframe: np.ndarray = attrs.field( + default=None, + repr=_data_repr, + eq=attrs.cmp_using(eq=_cmp), + converter=attrs.Converter(format_array_like, takes_field=True), # type: ignore + validator=validate_1d_array, + ) """A (N,) numpy array specifying the midpoint timing of each sample or frame.""" - total_duration: float = attrs.field(default=None, repr=True) + total_duration: float = attrs.field( + default=None, + repr=True, + converter=attrs.Converter(format_scalar_like, takes_field=True), # type: ignore + validator=attrs.validators.optional(attrs.validators.instance_of(float)), + ) """A float representing the total duration of the dataset.""" - uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) - """A (N,) numpy array specifying the uptake value of each sample or frame.""" + + def __attrs_post_init__(self) -> None: + """Enforce presence and basic consistency of PET data fields at + instantiation time. + + Specifically, the length of the frame_time and uptake attributes must + match the last dimension of the data (number of frames). + + Computes the values for the private attributes. + """ + + def _check_attr_vol_length_match( + _attr_name: str, _value: np.ndarray | None, _n_frames: int + ) -> None: + if _value is not None and len(_value) != _n_frames: + raise ValueError( + ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG.format( + attribute=_attr_name, + n_frames=_n_frames, + attr_len=len(_value), + ) + ) + + n_frames = int(self.dataobj.shape[-1]) + _check_attr_vol_length_match("midframe", self.midframe, n_frames) + + # Ensure that the total duration is larger than last midframe + if self.total_duration <= self.midframe[-1]: + raise ValueError( + TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG.format( + total_duration=self.total_duration, + last_midframe=self.midframe[-1], + ) + ) def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]: return (self.midframe[idx],) @@ -183,46 +390,38 @@ def from_filename(cls, filename: Path | str) -> Self: data = {k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")} return cls(**data) - @classmethod - def load( - cls, filename: Path | str, json_file: Path | str, brainmask_file: Path | str | None = None - ) -> Self: - """Load PET data.""" - filename = Path(filename) - if filename.name.endswith(".h5"): - return cls.from_filename(filename) - - img = load_api(filename, SpatialImage) - retval = cls( - dataobj=img.get_fdata(dtype="float32"), - affine=img.affine, - ) - - # Load metadata - with open(json_file, "r") as f: - metadata = json.load(f) - - frame_duration = np.array(metadata["FrameDuration"]) - frame_times_start = np.array(metadata["FrameTimesStart"]) - midframe = frame_times_start + frame_duration / 2 - - retval.midframe = midframe - retval.total_duration = float(frame_times_start[-1] + frame_duration[-1]) + def to_nifti( + self, + filename: Path | str | None = None, + write_hmxfms: bool = False, + order: int = 3, + ) -> nb.nifti1.Nifti1Image: + """ + Export the PET object to disk (NIfTI, temporal attribute files). - assert len(retval.midframe) == retval.dataobj.shape[-1] + Parameters + ---------- + filename : :obj:`os.pathlike`, optional + The output NIfTI file path. + write_hmxfms : :obj:`bool`, optional + If ``True``, the head motion affines will be written out to filesystem + with BIDS' X5 format. + order : :obj:`int`, optional + The interpolation order to use when resampling the data. - if brainmask_file: - mask = load_api(brainmask_file, SpatialImage) - retval.brainmask = np.asanyarray(mask.dataobj) + Returns + ------- + :obj:`~nibabel.nifti1.Nifti1Image` + NIfTI image written to disk. + """ - return retval + return to_nifti(self, filename=filename, write_hmxfms=write_hmxfms, order=order) def from_nii( filename: Path | str, - frame_time: np.ndarray | list[float], + temporal_file: Path | str, brainmask_file: Path | str | None = None, - frame_duration: np.ndarray | list[float] | None = None, ) -> PET: """ Load PET data from NIfTI, creating a PET object with appropriate metadata. @@ -231,6 +430,7 @@ def from_nii( ---------- filename : :obj:`os.pathlike` The NIfTI file. +<<<<<<< HEAD frame_time : :obj:`~numpy.ndarray` or :obj:`list` of :obj:`float` The start times of each frame relative to the beginning of the acquisition. brainmask_file : :obj:`os.pathlike`, optional @@ -240,6 +440,14 @@ def from_nii( The duration of each frame. If :obj:`None`, it is derived by the difference of consecutive frame times, defaulting the last frame to match the second-last. +======= + temporal_file : :obj:`os.pathlike` + A JSON file containing temporal data. It must at least contain + ``frame_time`` data. + brainmask_file : :obj:`os.pathlike`, optional + A brainmask NIfTI file. If provided, will be loaded and + stored in the returned dataset. +>>>>>>> 387dfe63 (ENH: Validate PET data objects' attributes at instantiation) Returns ------- @@ -254,62 +462,103 @@ def from_nii( """ filename = Path(filename) - # Load from NIfTI + + # 1) Load a NIfTI img = load_api(filename, SpatialImage) - data = img.get_fdata(dtype=np.float32) - pet_obj = PET( - dataobj=data, + fulldata = get_data(img) + + # 2) Load the temporal data + with open(temporal_file, "r") as f: + temporal_attrs = json.load(f) + + frame_time = temporal_attrs.get(FRAME_TIME_START_KEY, None) + if frame_time is None: + raise RuntimeError(TEMPORAL_FILE_KEY_ERROR_MSG.format(key=FRAME_TIME_START_KEY)) + + # 3) If a brainmask_file was provided, load it + brainmask_data = None + if brainmask_file is not None: + mask_img = load_api(brainmask_file, SpatialImage) + brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool) + + # 4) Compute temporal attributes + midframe, total_duration = _compute_temporal_markers(np.asarray(frame_time)) + + # 5) Create and return the PET instance + return PET( + dataobj=fulldata, affine=img.affine, + brainmask=brainmask_data, + midframe=midframe, + total_duration=total_duration, ) - pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum) - # Convert to a float32 numpy array and zero out the earliest time - frame_time_arr = np.array(frame_time, dtype=np.float32) - frame_time_arr -= frame_time_arr[0] - pet_obj.midframe = frame_time_arr +def _compute_temporal_markers(frame_time: np.ndarray) -> Tuple[np.ndarray, float]: + """Compute the frame temporal markers from the frame time values. - # If the user doesn't provide frame_duration, we derive it: - if frame_duration is None: - durations = _compute_frame_duration(pet_obj.midframe) - else: - durations = np.array(frame_duration, dtype=np.float32) + Computes the midframe times and the total duration following the principles + detailed below. - # Set total_duration and shift frame_time to the midpoint - pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1]) - pet_obj.midframe = frame_time_arr + 0.5 * durations + Let :math:`K` be the number of frames and :math:`t_{k}` be the :math:`k`-th + (start) frame time. For each frame :math:`k`, the frame duration + :math:`d_{k}` is defined as the difference between consecutive frame times: - # If a brain mask is provided, load and attach - if brainmask_file is not None: - mask_img = load_api(brainmask_file, SpatialImage) - pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool) + .. math:: + d_{k} = t_{k+1} - t_{k} + + If necessary, the last frame duration is set to the value of the second to + last frame to match the appropriate dimensionality in this implementation. - return pet_obj + Per-frame midpoints :math:`m_{k}` are computed as: + .. math:: + m_{k} = t_{k} + \\frac{d_k}{2} -def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray: - """Compute the frame duration from the midframe values. + The total duration :math:`D` of the acquisition is a scalar computed as the + sum of the frame durations: + + .. math:: + D = \\sum_{k=1}^{K} d_{k} + + or, equivalently, the difference between the last frame start and its + duration once the frame times have been time-origin shifted: + + .. math:: + D = t_{K} - d_{K} + + Frame times are time-origin shifted (i.e. the earliest time is zeroed out) + if not already done at the beginning of the process for the sake of + simplicity. Parameters ---------- - midframe : :obj:`~numpy.ndarray` - Midframe time values. + frame_time : :obj:`~numpy.ndarray` + Frame time values. Returns ------- - durations : :obj:`~numpy.ndarray` - Frame duration. + :obj:`tuple` + Midpoint timing of each frame and total duration """ + # Time-origin shift: zero out the earliest time if necessary + # Flatten the array in case it is not a 1D array + if not np.isclose(frame_time.ravel()[0], 0): + frame_time -= frame_time.flat[0] + # If shape is e.g. (N,), then we can do - durations = np.diff(midframe) - if len(durations) == (len(midframe) - 1): - durations = np.append(durations, durations[-1]) # last frame same as second-last + frame_duration = np.diff(frame_time) + if len(frame_duration) == (len(frame_time) - 1): + frame_duration = np.append(frame_duration, frame_duration[-1]) # last frame same as second-last + + midframe = frame_time + frame_duration / 2 + total_duration = float(frame_time[-1] + frame_duration[-1]) - return durations + return midframe, total_duration -def _compute_uptake_statistic(data: np.ndarray, stat_func: Callable = np.sum): +def compute_uptake_statistic(data: np.ndarray, stat_func: Callable[..., np.ndarray] = np.sum): """Compute a statistic over all voxels for each frame on a PET sequence. Assumes the last dimension corresponds to the number of frames in the diff --git a/test/conftest.py b/test/conftest.py index 4b9a091ce..7c79894f7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -323,10 +323,17 @@ def setup_random_pet_data(request): n_frames = 5 vol_size = (4, 4, 4) - midframe = np.arange(n_frames, dtype=np.float32) + 1 - total_duration = float(n_frames + 1) + frame_time = np.arange(n_frames, dtype=np.float32) + 1 if marker: - n_frames, vol_size, midframe, total_duration = marker.args + n_frames, vol_size, frame_time = marker.args + + frame_time = np.asarray(frame_time) + frame_time -= frame_time[0] + frame_duration = np.diff(frame_time) + if len(frame_duration) == (len(frame_time) - 1): + frame_duration = np.append(frame_duration, frame_duration[-1]) + midframe = frame_time + frame_duration / 2 + total_duration = float(frame_time[-1] + frame_duration[-1]) rng = request.node.rng @@ -339,6 +346,7 @@ def setup_random_pet_data(request): pet_dataobj, affine, brainmask_dataobj, + frame_time, midframe, total_duration, ) diff --git a/test/test_data.py b/test/test_data.py index 15a7ecb1d..bd69c756c 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -21,6 +21,7 @@ # https://www.nipreps.org/community/licensing/ # +import json import os from typing import Optional @@ -279,6 +280,11 @@ def test_load_pet_from_nii(monkeypatch, tmp_path): nb.save(img, fname) nb.save(mask_img, mask) + temporal_fname = tmp_path / "temporal.json" + temporal_data = {"frame_time": np.ones(4).tolist()} + with temporal_fname.open("w", encoding="utf-8") as f: + json.dump(temporal_data, f, ensure_ascii=False, indent=2, sort_keys=True) + called = {} sentinel = object() @@ -290,9 +296,9 @@ def dummy_from_nii(filename, brainmask_file=None, **kwargs): monkeypatch.setattr(pet, "from_nii", dummy_from_nii) - retval = data.load(fname, brainmask_file=mask, frame_time=np.zeros((4,))) + retval = data.load(fname, brainmask_file=mask, temporal_file=temporal_fname) assert retval is sentinel assert called["filename"] == fname assert called["brainmask_file"] == mask - assert "frame_time" in called["kwargs"] + assert "temporal_file" in called["kwargs"] diff --git a/test/test_data_pet.py b/test/test_data_pet.py index 1e40cf8a9..b336c07cc 100644 --- a/test/test_data_pet.py +++ b/test/test_data_pet.py @@ -22,36 +22,57 @@ # import json +import math from pathlib import Path +from typing import Any, Type +import attrs import nibabel as nb import numpy as np import pytest from nitransforms.linear import Affine -from nifreeze.data.pet import PET, _compute_frame_duration, _compute_uptake_statistic, from_nii +from nifreeze.data import load as nifreeze_load +from nifreeze.data.pet import ( + ARRAY_ATTRIBUTE_NDIM_ERROR_MSG, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ATTRIBUTE_ABSENCE_ERROR_MSG, + ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG, + FRAME_TIME_START_KEY, + PET, + SCALAR_ATTRIBUTE_ERROR_MSG, + TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG, + TEMPORAL_FILE_KEY_ERROR_MSG, + _compute_temporal_markers, + compute_uptake_statistic, + format_array_like, + format_scalar_like, + from_nii, + validate_1d_array, +) from nifreeze.utils.ndimage import load_api -@pytest.fixture -def random_dataset(setup_random_pet_data) -> PET: - """Create a PET dataset with random data for testing.""" - - ( - pet_dataobj, - affine, - brainmask_dataobj, - midframe, - total_duration, - ) = setup_random_pet_data - - return PET( - dataobj=pet_dataobj, - affine=affine, - brainmask=brainmask_dataobj, - midframe=midframe, - total_duration=total_duration, - ) +def _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj): + pet = nb.Nifti1Image(pet_dataobj, affine) + brainmask = nb.Nifti1Image(brainmask_dataobj, affine) + + return pet, brainmask + + +def _serialize_pet_data(pet, brainmask, frame_time, _tmp_path): + pet_fname = _tmp_path / "pet.nii.gz" + brainmask_fname = _tmp_path / "brainmask.nii.gz" + temporal_fname = _tmp_path / "temporal.json" + + nb.save(pet, pet_fname) + nb.save(brainmask, brainmask_fname) + + temporal_data = {FRAME_TIME_START_KEY: frame_time.tolist()} + with temporal_fname.open("w", encoding="utf-8") as f: + json.dump(temporal_data, f, ensure_ascii=False, indent=2, sort_keys=True) + + return pet_fname, brainmask_fname, temporal_fname @pytest.fixture @@ -64,17 +85,284 @@ def random_nifti_file(tmp_path, setup_random_uniform_spatial_data) -> Path: @pytest.mark.parametrize( - "midframe, expected", + "attr_name, value, expected_exc, expected_msg", [ - ([1.0, 4.0], [3.0, 3.0]), - ([0.0, 5.0, 9.0, 12.0], [5.0, 4.0, 3.0, 3.0]), + ("any_name", None, ValueError, ATTRIBUTE_ABSENCE_ERROR_MSG), + ( + "any_name", + [[10.0], [20.0, 30.0], [40.0], [50.0]], + TypeError, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ), # Ragged + ( + "any_name", + np.array([[-0.9], [0.06, 0.12], [0.27], [0.08]], dtype=object), + TypeError, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ), # Ragged ], ) -def test_compute_frame_duration(midframe, expected): - midframe = np.array(midframe) - expected = np.array(expected) - durations = _compute_frame_duration(midframe) - np.testing.assert_allclose(durations, expected) +def test_format_array_like_errors(attr_name, value, expected_exc, expected_msg): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + with pytest.raises(expected_exc, match=expected_msg.format(attribute=attr_name)): + format_array_like(value, dummy_attr) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "attr_name, value, expected_exc, expected_msg", + [ + ("any_name", None, ValueError, ATTRIBUTE_ABSENCE_ERROR_MSG), + ("any_name", True, ValueError, SCALAR_ATTRIBUTE_ERROR_MSG), + ("any_name", (2, 2), ValueError, SCALAR_ATTRIBUTE_ERROR_MSG), + ("any_name", np.asarray([1.0, 2.0]), ValueError, SCALAR_ATTRIBUTE_ERROR_MSG), + ], +) +def test_format_scalar_errors(attr_name, value, expected_exc, expected_msg): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + + with pytest.raises(expected_exc, match=expected_msg.format(attribute=attr_name)): + format_scalar_like(value, dummy_attr) # type: ignore[arg-type] + + +@pytest.mark.parametrize("value", [[1.0, 2.0, 3.0, 4.0], (1.0, 2.0, 3.0, 4.0)]) +@pytest.mark.parametrize("attr_name", ("midframe",)) +def test_format_array_like(value, attr_name): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + + obtained = format_array_like(value, dummy_attr) + assert isinstance(obtained, np.ndarray) + assert obtained.shape == np.asarray(value).shape + assert np.allclose(obtained, np.asarray(value)) + + +@pytest.mark.parametrize("value", [1.0, [2.0], (3.0,), np.array(4.0), np.array([5.0])]) +def test_format_scalar_like(value): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {"total_duration": attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + + obtained = format_scalar_like(value, dummy_attr) + assert isinstance(obtained, float) + assert np.allclose(obtained, np.asarray(value)) + + +@pytest.mark.parametrize("attr_name, value", [("my_attr", np.asarray([1.0, 2.0, 3.0, 4.0]))]) +@pytest.mark.parametrize("extra_dimensions", (1, 2)) +@pytest.mark.parametrize("transpose", (True, False)) +def test_validate_1d_arr_errors( + request, monkeypatch, attr_name, value, extra_dimensions, transpose +): + def _add_extra_dim(_rng, _attr_name, _extra_dimensions, _transpose, _value): + _arr = np.concatenate( + [ + _value[:, None], + rng.random((_value.size, _extra_dimensions)), + ], + axis=1, + ) + _arr = _arr.T if _transpose else _arr + return _arr + + rng = request.node.rng + _value = _add_extra_dim(rng, attr_name, extra_dimensions, transpose, value) + + monkeypatch.setattr(PET, "__init__", lambda self, *a, **k: None) + + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + new_attr = dummy_attr_cls.__attrs_attrs__[0] + + # Replace PET's attribute metadata with just the new_attr. + # attrs.fields() reads PET.__attrs_attrs__ at runtime, so setting this + # effectively "removes" previous attributes and leaves only our single one. + monkeypatch.setattr(PET, "__attrs_attrs__", (new_attr,), raising=False) + # Also set a matching annotation dict + monkeypatch.setattr(PET, "__annotations__", {attr_name: object()}, raising=False) + + # Instantiate and obtain the attrs.Attribute from PET + inst = PET() + dummy_attr = attrs.fields(PET)[0] + # assert isinstance(dummy_attr, attrs.Attribute) + # assert dummy_attr.name == attr_name + + with pytest.raises( + ValueError, + match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr_name), + ): + validate_1d_array(inst, dummy_attr, _value) + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +@pytest.mark.parametrize("attr_name", ("midframe",)) +@pytest.mark.parametrize("extra_dimensions", (1, 2)) +@pytest.mark.parametrize("transpose", (True, False)) +def test_pet_instantiation_attribute_validate_1d_arr_errors( + request, setup_random_pet_data, attr_name, extra_dimensions, transpose +): + def _add_extra_dim(_rng, _attr_name, _extra_dimensions, _transpose, **_kwargs): + _arr = np.concatenate( + [ + _kwargs[_attr_name][:, None], + rng.random((_kwargs[_attr_name].size, _extra_dimensions)), + ], + axis=1, + ) + _kwargs[_attr_name] = _arr.T if _transpose else _arr + return _kwargs + + rng = request.node.rng + pet_dataobj, affine, _, _, midframe, total_duration = setup_random_pet_data + + attrs_dict = dict( + midframe=midframe, + total_duration=total_duration, + ) + _attrs_dict = _add_extra_dim(rng, attr_name, extra_dimensions, transpose, **attrs_dict) + + with pytest.raises( + ValueError, + match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr_name), + ): + PET(dataobj=pet_dataobj, affine=affine, **_attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("attr_name", ("midframe", "total_duration")) +def test_pet_instantiation_attribute_convert_absence_errors( + setup_random_uniform_spatial_data, + attr_name, +): + data, affine = setup_random_uniform_spatial_data + + n_frames = data.shape[-1] + # Create a dict with default valid attribute values + attrs_dict: dict[str, np.ndarray | float | None] = dict( + midframe=np.ones(n_frames, dtype=np.float32), + total_duration=1.0, + ) + + # Override only the attribute under test + attrs_dict[attr_name] = None + + with pytest.raises(ValueError, match=ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr_name)): + PET(dataobj=data, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("attr_name", ("midframe",)) +@pytest.mark.parametrize( + "value", + [ + ([[10.0], [20.0, 30.0], [40.0], [50.0]]), # Ragged + (np.array([[-0.9], [0.06, 0.12], [0.27], [0.08]], dtype=object)), # Ragged + ], +) +def test_pet_instantiation_attribute_convert_object_errors( + setup_random_uniform_spatial_data, attr_name, value +): + data, affine = setup_random_uniform_spatial_data + + n_frames = data.shape[-1] + # Create a dict with some valid attributes + attrs_dict = dict( + midframe=np.ones(n_frames, dtype=np.float32), + total_duration=n_frames + 1, + ) + + # Override only the attribute under test + attrs_dict[attr_name] = value + + with pytest.raises( + TypeError, match=ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr_name) + ): + PET(dataobj=data, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +@pytest.mark.parametrize("attr_name", ("midframe",)) +@pytest.mark.parametrize( + ("extra_volume_count", "extra_attribute_count"), + [(1, 0), (2, 0), (2, 1), (0, 1), (0, 2), (1, 2)], +) +def test_pet_instantiation_attribute_vol_mismatch_error( + setup_random_pet_data, attr_name, extra_volume_count, extra_attribute_count +): + pet_dataobj, affine, _, _, midframe, total_duration = setup_random_pet_data + + n_frames = int(pet_dataobj.shape[-1]) + attrs_dict = dict( + midframe=midframe, + total_duration=total_duration, + ) + + # Add extra volumes: simply concatenate the last volume + if extra_volume_count: + extra_dwi_dataobj = np.tile(pet_dataobj[..., -1:], (1, extra_volume_count)) + pet_dataobj = np.concatenate((pet_dataobj, extra_dwi_dataobj), axis=-1) + n_frames = int(pet_dataobj.shape[-1]) + # Add extra values to attribute: simply concatenate the attribute + if extra_attribute_count: + base = attrs_dict[attr_name] + extra_vals = np.repeat(base[-1], extra_attribute_count) + attrs_dict[attr_name] = np.concatenate((base, extra_vals)) + + attr_val = attrs_dict[attr_name] + + with pytest.raises( + ValueError, + match=ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG.format( + attribute=attr_name, n_frames=n_frames, attr_len=len(attr_val) + ), + ): + PET(dataobj=pet_dataobj, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +@pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 2.0, 3.0])) +@pytest.mark.parametrize("attr_name, excess_value", + [("midframe", 1.0), ("midframe", 2.0), ("total_duration", 0.0), ("total_duration", -1.0), ("total_duration", -2.0)], +) +def test_pet_instantiation_attribute_inconsistency_error(setup_random_pet_data, attr_name, excess_value): + pet_dataobj, affine, _, _, midframe, total_duration = setup_random_pet_data + + if attr_name == "midframe": + midframe[-1] = total_duration + excess_value + elif attr_name == "total_duration": + total_duration = midframe[-1] + excess_value + + attrs_dict = dict( + midframe=midframe, + total_duration=total_duration, + ) + + with pytest.raises( + ValueError, + match=TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG.format( + total_duration=total_duration, last_midframe=midframe[-1] + ), + ): + PET(dataobj=pet_dataobj, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "frame_time, expected_midframe, expected_total_duration", + [ + ([1.0, 4.0], [1.5, 4.5], 6.0), + ([0.0, 5.0, 9.0, 12.0], [2.5, 7.0, 10.5, 13.5], 15.0), + ], +) +def test_compute_temporal_markers(frame_time, expected_midframe, expected_total_duration): + frame_time = np.array(frame_time) + expected_midframe = np.array(expected_midframe) + midframe, total_duration = _compute_temporal_markers(frame_time) + np.testing.assert_allclose(midframe, expected_midframe) + assert np.isclose(total_duration, expected_total_duration) @pytest.mark.parametrize("stat_func", (np.sum, np.mean, np.std)) @@ -83,131 +371,243 @@ def test_compute_uptake_statistic(stat_func): data = rng.random((4, 4, 4, 5), dtype=np.float32) expected = stat_func(data.reshape(-1, data.shape[-1]), axis=0) - obtained = _compute_uptake_statistic(data, stat_func=stat_func) + obtained = compute_uptake_statistic(data, stat_func=stat_func) np.testing.assert_array_equal(obtained, expected) -@pytest.mark.parametrize( - ("brainmask_file", "frame_time", "frame_duration"), - [ - (None, [0.0, 5.0], [5.0, 5.0]), - (None, [10.0, 15.0], [5.0, 5.0]), - ("mask.nii.gz", [0.0, 5.0], [5.0, 5.0]), - ("mask.nii.gz", [0.0, 5.0], None), - ], -) -def test_from_nii(tmp_path, random_nifti_file, brainmask_file, frame_time, frame_duration): - filename = random_nifti_file - img = load_api(filename, nb.Nifti1Image) - if brainmask_file: - mask_data = np.ones(img.get_fdata().shape[:-1], dtype=bool) - mask_img = nb.Nifti1Image(mask_data.astype(np.uint8), img.affine) - mask_img.to_filename(brainmask_file) - - pet_obj = from_nii( - filename, - brainmask_file=brainmask_file, - frame_time=frame_time, - frame_duration=frame_duration, - ) - assert isinstance(pet_obj, PET) - assert pet_obj.dataobj.shape == img.get_fdata().shape - np.testing.assert_array_equal(pet_obj.affine, img.affine) +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_from_nii_errors(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + pet_fname = tmp_path / "pet.nii.gz" + brainmask_fname = tmp_path / "brainmask.nii.gz" + temporal_fname = tmp_path / "temporal.json" + + nb.save(pet, pet_fname) + nb.save(brainmask, brainmask_fname) + + # Check frame time + temporal_data = {"any_key": frame_time.tolist()} + with temporal_fname.open("w", encoding="utf-8") as f: + json.dump(temporal_data, f, ensure_ascii=False, indent=2, sort_keys=True) + + with pytest.raises( + RuntimeError, match=TEMPORAL_FILE_KEY_ERROR_MSG.format(key=FRAME_TIME_START_KEY) + ): + from_nii( + pet_fname, + temporal_fname, + brainmask_file=brainmask_fname, + ) + + +@pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 4.0, 6.0])) +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_from_nii(tmp_path, setup_random_pet_data): + from nifreeze.data.base import BaseDataset - # Convert to a float32 numpy array and zero out the earliest time - frame_time_arr = np.array(frame_time, dtype=np.float32) - frame_time_arr -= frame_time_arr[0] - if frame_duration is None: - durations = _compute_frame_duration(frame_time_arr) - else: - durations = np.array(frame_duration, dtype=np.float32) + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data - expected_total_duration = float(frame_time_arr[-1] + durations[-1]) - expected_midframe = frame_time_arr + 0.5 * durations + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) - np.testing.assert_allclose(pet_obj.midframe, expected_midframe) - assert pet_obj.total_duration == expected_total_duration + pet_fname, brainmask_fname, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) - if brainmask_file: - assert pet_obj.brainmask is not None - np.testing.assert_array_equal(pet_obj.brainmask, mask_data) + # Read back using public API + pet_obj_from_nii = from_nii(pet_fname, temporal_fname, brainmask_file=brainmask_fname) + assert isinstance(pet_obj_from_nii, PET) + + attrs_dict: dict[str, np.ndarray | float | None] = dict( + midframe=midframe, + total_duration=total_duration, + ) + + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + # No need to check base class attributes: remove them + base_attrs = [ + a.name + for a in attrs.fields(BaseDataset) + if not a.name.startswith("_") and not a.name.isdigit() + ] + attrs_to_check = [_ for _ in attrs_to_check if _ not in base_attrs] + + for attr_name in attrs_to_check: + val_direct = attrs_dict[attr_name] + val_from_nii = getattr(pet_obj_from_nii, attr_name) + + if val_direct is None or val_from_nii is None: + assert val_direct is None and val_from_nii is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_from_nii.shape + assert np.allclose(val_direct, val_from_nii), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_from_nii), f"{attr_name} values differ" + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_to_nifti(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_to_nifti(tmp_path, random_dataset): out_filename = tmp_path / "random_pet_out.nii.gz" - random_dataset.to_nifti(str(out_filename)) + pet_obj.to_nifti(str(out_filename)) assert out_filename.exists() loaded_img = load_api(str(out_filename), nb.Nifti1Image) - assert np.allclose(loaded_img.get_fdata(), random_dataset.dataobj) - assert np.allclose(loaded_img.affine, random_dataset.affine) + assert np.allclose(loaded_img.get_fdata(), pet_obj.dataobj) + assert np.allclose(loaded_img.affine, pet_obj.affine) units = loaded_img.header.get_xyzt_units() assert units[0] == "mm" -@pytest.mark.parametrize( - ("frame_time", "frame_duration"), - [ - ([0.0, 5.0], [5.0, 5.0]), - ], -) -def test_round_trip(tmp_path, random_nifti_file, frame_time, frame_duration): - filename = random_nifti_file - img = load_api(filename, nb.Nifti1Image) - pet_obj = from_nii(filename, frame_time=frame_time, frame_duration=frame_duration) +@pytest.mark.random_pet_data(2, (2, 2, 2), np.asarray([0.0, 5.0])) +def test_round_trip(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + pet_fname, _, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) + + img = load_api(pet_fname, nb.Nifti1Image) + pet_obj = from_nii(pet_fname, temporal_fname) out_fname = tmp_path / "random_pet_out.nii.gz" pet_obj.to_nifti(out_fname) assert out_fname.exists() loaded_img = load_api(out_fname, nb.Nifti1Image) - np.testing.assert_array_equal(loaded_img.affine, img.affine) + assert np.allclose(loaded_img.affine, img.affine) np.testing.assert_allclose(loaded_img.get_fdata(), img.get_fdata()) units = loaded_img.header.get_xyzt_units() assert units[0] == "mm" -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_pet_set_transform_updates_motion_affines(random_dataset): +def test_equality_operator(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + pet_fname, brainmask_fname, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) + + # Read back using public API + pet_obj_from_nii = from_nii(pet_fname, temporal_fname, brainmask_file=brainmask_fname) + + # Direct instantiation with the same arrays + pet_obj_direct = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) + + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + + # Sanity checks (element-wise) + for attr_name in attrs_to_check: + val_direct = getattr(pet_obj_direct, attr_name) + val_from_nii = getattr(pet_obj_from_nii, attr_name) + + if val_direct is None or val_from_nii is None: + assert val_direct is None and val_from_nii is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_from_nii.shape + assert np.allclose(val_direct, val_from_nii), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_from_nii), f"{attr_name} values differ" + + # Test equality operator + assert pet_obj_direct == pet_obj_from_nii + + # Test equality operator against an instance from HDF5 + hdf5_filename = tmp_path / "test_pet.h5" + pet_obj_from_nii.to_filename(hdf5_filename) + + round_trip_pet_obj = PET.from_filename(hdf5_filename) + + # Symmetric equality + assert pet_obj_from_nii == round_trip_pet_obj + assert round_trip_pet_obj == pet_obj_from_nii + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_pet_set_transform_updates_motion_affines(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) + idx = 2 - data_before = np.copy(random_dataset.dataobj[..., idx]) + data_before = np.copy(pet_obj.dataobj[..., idx]) affine = np.eye(4) - random_dataset.set_transform(idx, affine) + pet_obj.set_transform(idx, affine) - np.testing.assert_allclose(random_dataset.dataobj[..., idx], data_before) - assert random_dataset.motion_affines is not None - assert len(random_dataset.motion_affines) == len(random_dataset) - assert isinstance(random_dataset.motion_affines[idx], Affine) - np.testing.assert_array_equal(random_dataset.motion_affines[idx].matrix, affine) + np.testing.assert_allclose(pet_obj.dataobj[..., idx], data_before) + assert pet_obj.motion_affines is not None + assert len(pet_obj.motion_affines) == len(pet_obj) + assert isinstance(pet_obj.motion_affines[idx], Affine) + assert np.allclose(pet_obj.motion_affines[idx].matrix, affine) - vol, aff, time = random_dataset[idx] - assert aff is random_dataset.motion_affines[idx] + vol, aff, time = pet_obj[idx] + assert aff is pet_obj.motion_affines[idx] -@pytest.mark.random_uniform_spatial_data((2, 2, 2, 2), 0.0, 1.0) -def test_pet_load(request, tmp_path, setup_random_uniform_spatial_data): - data, affine = setup_random_uniform_spatial_data - img = nb.Nifti1Image(data, affine) - fname = tmp_path / "pet.nii.gz" - img.to_filename(fname) +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_pet_load(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data - brainmask_dataobj = request.node.rng.choice([True, False], size=data.shape[:3]).astype( - np.uint8 + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + # Direct instantiation with the same arrays + pet_obj_direct = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, ) - brainmask = nb.Nifti1Image(brainmask_dataobj, affine) - brainmask_fname = tmp_path / "brainmask.nii.gz" - brainmask.to_filename(brainmask_fname) - - json_file = tmp_path / "pet.json" - metadata = { - "FrameDuration": [1.0, 1.0], - "FrameTimesStart": [0.0, 1.0], - } - json_file.write_text(json.dumps(metadata)) - - pet_obj = PET.load(fname, json_file, brainmask_fname) - - assert pet_obj.dataobj.shape == data.shape - assert np.allclose(pet_obj.midframe, [0.5, 1.5]) - assert pet_obj.total_duration == 2.0 - if pet_obj.brainmask is not None: - assert pet_obj.brainmask.shape == brainmask_dataobj.shape + + pet_fname, brainmask_fname, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) + + pet_from_nii_kwargs = {"temporal_file": temporal_fname} + + pet_obj_load = nifreeze_load(pet_fname, brainmask_fname, **pet_from_nii_kwargs) + + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + + # Sanity checks (element-wise) + for attr_name in attrs_to_check: + val_direct = getattr(pet_obj_direct, attr_name) + val_load = getattr(pet_obj_load, attr_name) + + if val_direct is None or val_load is None: + assert val_direct is None and val_load is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_load.shape + assert np.allclose(val_direct, val_load), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_load), f"{attr_name} values differ" diff --git a/test/test_estimator.py b/test/test_estimator.py index f1547bd0a..cc4df583f 100644 --- a/test/test_estimator.py +++ b/test/test_estimator.py @@ -78,13 +78,25 @@ def __getitem__(self, idx): class DummyPETDataset(BaseDataset): - def __init__(self, pet_dataobj, affine, brainmask_dataobj, midframe, total_duration): + def __init__( + self, + pet_dataobj, + affine, + brainmask_dataobj, + frame_time, + uptake, + frame_duration, + midframe, + total_duration, + ): self.dataobj = pet_dataobj self.affine = affine self.brainmask = brainmask_dataobj + self.frame_time = frame_time + self.uptake = uptake + self.frame_duration = frame_duration self.midframe = midframe self.total_duration = total_duration - self.uptake = np.sum(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0) def __len__(self): return self.dataobj.shape[-1] @@ -155,11 +167,23 @@ def test_estimator_iterator_index_match( pet_dataobj, affine, brainmask_dataobj, + frame_time, + uptake, + frame_duration, midframe, total_duration, ) = setup_random_pet_data - dataset = DummyPETDataset(pet_dataobj, affine, brainmask_dataobj, midframe, total_duration) + dataset = DummyPETDataset( + pet_dataobj, + affine, + brainmask_dataobj, + frame_time, + uptake, + frame_duration, + midframe, + total_duration, + ) uptake = dataset.uptake kwargs = {"uptake": uptake} else: diff --git a/test/test_integration_pet.py b/test/test_integration_pet.py index 68b2a318e..207de2831 100644 --- a/test/test_integration_pet.py +++ b/test/test_integration_pet.py @@ -21,8 +21,10 @@ # https://www.nipreps.org/community/licensing/ # +import math import types +import attrs import numpy as np import pytest @@ -30,19 +32,11 @@ from nifreeze.estimator import PETMotionEstimator -@pytest.fixture -def random_dataset(setup_random_pet_data) -> PET: - """Create a PET dataset with random data for testing.""" +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 13.0, 17.0, 30.0, 33.0])) +def test_lofo_split_shapes(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data - ( - pet_dataobj, - affine, - brainmask_dataobj, - midframe, - total_duration, - ) = setup_random_pet_data - - return PET( + pet_obj = PET( dataobj=pet_dataobj, affine=affine, brainmask=brainmask_dataobj, @@ -50,31 +44,63 @@ def random_dataset(setup_random_pet_data) -> PET: total_duration=total_duration, ) - -@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0]), 5.0) -def test_lofo_split_shapes(random_dataset, tmp_path): idx = 2 - (train_data, train_times), (test_data, test_time) = random_dataset.lofo_split(idx) - assert train_data.shape[-1] == random_dataset.dataobj.shape[-1] - 1 - np.testing.assert_array_equal(test_data, random_dataset.dataobj[..., idx]) - np.testing.assert_array_equal(train_times, np.delete(random_dataset.midframe, idx)) - assert test_time == random_dataset.midframe[idx] + (train_data, train_times), (test_data, test_time) = pet_obj.lofo_split(idx) + assert train_data.shape[-1] == pet_obj.dataobj.shape[-1] - 1 + np.testing.assert_array_equal(test_data, pet_obj.dataobj[..., idx]) + np.testing.assert_array_equal(train_times, np.delete(pet_obj.midframe, idx)) + assert test_time == pet_obj.midframe[idx] + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0])) +def test_to_from_filename_roundtrip(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 2.0, 3.0]), 4.0) -def test_to_from_filename_roundtrip(random_dataset, tmp_path): out_file = tmp_path / "petdata" - random_dataset.to_filename(out_file) + pet_obj.to_filename(out_file) assert (tmp_path / "petdata.h5").exists() loaded = PET.from_filename(tmp_path / "petdata.h5") - np.testing.assert_allclose(loaded.dataobj, random_dataset.dataobj) - np.testing.assert_allclose(loaded.affine, random_dataset.affine) - np.testing.assert_allclose(loaded.midframe, random_dataset.midframe) - assert loaded.total_duration == random_dataset.total_duration + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + + # Sanity checks (element-wise) + for attr_name in attrs_to_check: + val_direct = getattr(pet_obj, attr_name) + val_loaded = getattr(loaded, attr_name) + + if val_direct is None or val_loaded is None: + assert val_direct is None and val_loaded is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_loaded.shape + assert np.allclose(val_direct, val_loaded), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_loaded), f"{attr_name} values differ" + + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 15.0, 20.0, 25.0, 30.0])) +def test_pet_motion_estimator_run(monkeypatch, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_pet_motion_estimator_run(random_dataset, monkeypatch): class DummyModel: def __init__(self, dataset, timepoints, xlim): self.dataset = dataset @@ -96,7 +122,7 @@ def run(self, cwd=None): monkeypatch.setattr("nifreeze.estimator.Registration", DummyRegistration) estimator = PETMotionEstimator(None) - affines = estimator.run(random_dataset) - assert len(affines) == len(random_dataset) + affines = estimator.run(pet_obj) + assert len(affines) == len(pet_obj) for mat in affines: np.testing.assert_array_equal(mat, np.eye(4)) diff --git a/test/test_model_pet.py b/test/test_model_pet.py index 6081a07d5..1c6bdcb4f 100644 --- a/test/test_model_pet.py +++ b/test/test_model_pet.py @@ -28,19 +28,11 @@ from nifreeze.model.pet import PETModel -@pytest.fixture -def random_dataset(setup_random_pet_data) -> PET: - """Create a PET dataset with random data for testing.""" - - ( - pet_dataobj, - affine, - brainmask_dataobj, - midframe, - total_duration, - ) = setup_random_pet_data - - return PET( +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0])) +def test_petmodel_fit_predict(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( dataobj=pet_dataobj, affine=affine, brainmask=brainmask_dataobj, @@ -48,13 +40,10 @@ def random_dataset(setup_random_pet_data) -> PET: total_duration=total_duration, ) - -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_fit_predict(random_dataset): model = PETModel( - dataset=random_dataset, - timepoints=random_dataset.midframe, - xlim=random_dataset.total_duration, + dataset=pet_obj, + timepoints=pet_obj.midframe, + xlim=pet_obj.total_duration, smooth_fwhm=0, thresh_pct=0, ) @@ -64,20 +53,40 @@ def test_petmodel_fit_predict(random_dataset): assert model.is_fitted # Predict at a specific timepoint - vol = model.fit_predict(random_dataset.midframe[2]) + vol = model.fit_predict(pet_obj.midframe[2]) assert vol is not None - assert vol.shape == random_dataset.shape3d - assert vol.dtype == random_dataset.dataobj.dtype + assert vol.shape == pet_obj.shape3d + assert vol.dtype == pet_obj.dataobj.dtype + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([1.0, 2.0, 3.0, 4.0, 5.0])) +def test_petmodel_invalid_init1(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_invalid_init(random_dataset): with pytest.raises(TypeError): - PETModel(dataset=random_dataset) + PETModel(dataset=pet_obj) + + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([5.0, 10.0, 15.0, 20.0, 25.0])) +def test_petmodel_time_check(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_time_check(random_dataset): bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32) with pytest.raises(ValueError): - PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0) + PETModel(dataset=pet_obj, timepoints=bad_times, xlim=60.0) From ecf6baba80225c3a0dbf0b6e642c53da57b0b7c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 6 Dec 2025 10:46:40 -0500 Subject: [PATCH 2/2] WIP: Override `to_nifti` Override `to_nifti` to be able to serialize a BIDS-compatible PET datasets. Requires reconstructing `frame_time` from `midframe` and `total_duration`. --- src/nifreeze/data/pet.py | 60 ++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py index f532f9923..b89d41598 100644 --- a/src/nifreeze/data/pet.py +++ b/src/nifreeze/data/pet.py @@ -40,6 +40,7 @@ from typing_extensions import Self from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim +from nifreeze.data.base import to_nifti as _base_to_nifti from nifreeze.utils.ndimage import get_data, load_api ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None" @@ -430,24 +431,12 @@ def from_nii( ---------- filename : :obj:`os.pathlike` The NIfTI file. -<<<<<<< HEAD - frame_time : :obj:`~numpy.ndarray` or :obj:`list` of :obj:`float` - The start times of each frame relative to the beginning of the acquisition. - brainmask_file : :obj:`os.pathlike`, optional - A brainmask NIfTI file. If provided, will be loaded and - stored in the returned dataset. - frame_duration : :obj:`~numpy.ndarray` or :obj:`list` of :obj:`float`, optional - The duration of each frame. - If :obj:`None`, it is derived by the difference of consecutive frame times, - defaulting the last frame to match the second-last. -======= temporal_file : :obj:`os.pathlike` A JSON file containing temporal data. It must at least contain ``frame_time`` data. brainmask_file : :obj:`os.pathlike`, optional A brainmask NIfTI file. If provided, will be loaded and stored in the returned dataset. ->>>>>>> 387dfe63 (ENH: Validate PET data objects' attributes at instantiation) Returns ------- @@ -579,3 +568,50 @@ def compute_uptake_statistic(data: np.ndarray, stat_func: Callable[..., np.ndarr """ return stat_func(data.reshape(-1, data.shape[-1]), axis=0) + + +def reconstruct_frame_time(midframe: np.ndarray, total_duration: float) -> np.ndarray: + """ + Reconstruct frame_time (start times) from midframe and total_duration, + assuming: + - mid[i] = t[i] + d[i]/2 + - d[-1] == d[-2] (last duration is duplicate of previous) + - returned frame_time uses t[0] == 0 convention + + No input verification / no loops. + """ + mid = np.asarray(midframe).ravel() + N = mid.size + if N <= 1: + return np.zeros(N, dtype=float) + + # Adjacent midpoint differences give half the sum of durations + s = 2.0 * np.diff(mid) # shape (N-1,) + + # Solve durations: d[0] = 2*mid[0]; d[i] = s[i-1] - d[i-1] + d0 = 2.0 * mid[0] + signs = (-1.0) ** np.arange(N-1) + d_prefix = d0 + signs * np.cumsum(signs * s) + + d = np.empty(N, float) + d[:-1] = d_prefix + d[-1] = d[-2] + + # optional consistency check using total_duration (not needed to compute) + implied_total = mid[-1] + d[-1] / 2.0 + assert abs(implied_total - float(total_duration)) < 1e-12 + + starts = np.cumsum(np.concatenate(([0.0], d[:-1]))) + return starts + + +#def test_reconstruct_frame_time(): +# # example +# t = np.array([0.0, 1.0, 2.5, 4.0]) +# d = np.diff(np.append(t, t[ +# -1] + 1.0)) # last duration duplicates previous: [1.0,1.5,1.5,1.5] +# mid = t + d / 2 +# total = float(t[-1] + d[-1]) +# +# recovered = reconstruct_frame_time_vectorized(mid, total) +# # recovered -> array([0. , 1. , 2.5, 4. ])