diff --git a/docs/notebooks/pet_motion_estimation.ipynb b/docs/notebooks/pet_motion_estimation.ipynb index dbef481d3..ba528406d 100644 --- a/docs/notebooks/pet_motion_estimation.ipynb +++ b/docs/notebooks/pet_motion_estimation.ipynb @@ -10,7 +10,7 @@ "from os import getenv\n", "from pathlib import Path\n", "\n", - "from nifreeze.data.pet import PET\n", + "from nifreeze.data.pet import from_nii\n", "\n", "# Install test data from gin.g-node.org:\n", "# $ datalad install -g https://gin.g-node.org/nipreps-data/tests-nifreeze.git\n", @@ -29,7 +29,7 @@ " DATA_PATH / \"pet_data\" / \"sub-02\" / \"ses-baseline\" / \"pet\" / \"sub-02_ses-baseline_pet.json\"\n", ")\n", "\n", - "pet_dataset = PET.load(pet_file, json_file)" + "pet_dataset = from_nii(pet_file, temporal_file=json_file)" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 78d83e2d9..9fb042f68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ license = "Apache-2.0" requires-python = ">=3.10" dependencies = [ - "attrs>=20.1.0", + "attrs>=24.1.0", "dipy>=1.5.0", "joblib", "nipype>=1.5.1,<2.0", diff --git a/src/nifreeze/data/__init__.py b/src/nifreeze/data/__init__.py index f9d1b245e..8249dcf43 100644 --- a/src/nifreeze/data/__init__.py +++ b/src/nifreeze/data/__init__.py @@ -76,7 +76,7 @@ def load( from nifreeze.data.dmri import from_nii as dmri_from_nii return dmri_from_nii(filename, brainmask_file=brainmask_file, **kwargs) - elif {"frame_time", "frame_duration"} & set(kwargs): + elif {"temporal_file"} & set(kwargs): from nifreeze.data.pet import from_nii as pet_from_nii return pet_from_nii(filename, brainmask_file=brainmask_file, **kwargs) diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py index 7e06fd913..b89d41598 100644 --- a/src/nifreeze/data/pet.py +++ b/src/nifreeze/data/pet.py @@ -28,7 +28,7 @@ from collections import namedtuple from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import Any, Tuple import attrs import h5py @@ -39,20 +39,228 @@ from nitransforms.resampling import apply from typing_extensions import Self -from nifreeze.data.base import BaseDataset, _cmp, _data_repr -from nifreeze.utils.ndimage import load_api +from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim +from nifreeze.data.base import to_nifti as _base_to_nifti +from nifreeze.utils.ndimage import get_data, load_api + +ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None" +"""PET initialization array attribute absence error message.""" + +ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = ( + "PET '{attribute}' must be a numeric homogeneous array-like object." +) +"""PET initialization array attribute object error message.""" + +ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array." +"""PET initialization array attribute ndim error message.""" + +ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG = """\ +PET '{attribute}' length does not match number of frames: \ +expected {n_frames} values, found {attr_len}.""" +"""PET attribute shape mismatch error message.""" + +TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG = """\ +PET 'total_duration' cannot be smaller than last 'midframe' value: \ +found {total_duration} and {last_midframe}.""" +"""PET attribute inconsistency error message.""" + +SCALAR_ATTRIBUTE_ERROR_MSG = ( + "PET '{attribute}' must be a numeric or single-element sequence object." 
+)
+"""PET initialization scalar attribute object error message."""
+
+TEMPORAL_FILE_KEY_ERROR_MSG = "{key} key not found in temporal file"
+"""PET temporal file key error message."""
+
+FRAME_TIME_START_KEY = "FrameTimesStart"
+"""PET frame time start key."""
+
+
+def format_scalar_like(value: Any, attr: attrs.Attribute) -> float:
+    """Convert ``value`` to a scalar.
+
+    Accepts:
+    - :obj:`float` or :obj:`int` (but rejects :obj:`bool`)
+    - Numpy scalar (:obj:`~numpy.generic`, e.g. :obj:`~numpy.floating`, :obj:`~numpy.integer`)
+    - :obj:`~numpy.ndarray` of size 1
+    - :obj:`list`/:obj:`tuple` of length 1
+
+    This function is intended for use as an attrs-style converter.
+
+    Parameters
+    ----------
+    value : :obj:`Any`
+        The value to format.
+    attr : :obj:`~attrs.Attribute`
+        The attribute being initialized; ``attr.name`` is used in the error message.
+
+    Returns
+    -------
+    formatted : :obj:`float`
+        The formatted value.
+
+    Raises
+    ------
+    :exc:`TypeError`
+        If the input cannot be converted to a scalar.
+    :exc:`ValueError`
+        If the value is ``None``, is of type :obj:`bool`, or does not have size/length 1.
+    """
+
+    if value is None:
+        raise ValueError(ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))
+
+    # Reject bool explicitly (bool is a subclass of int)
+    if isinstance(value, bool):
+        raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name))
+
+    # Numpy scalar (np.generic)
+    if isinstance(value, np.generic):
+        return float(value.item())
+
+    # Numpy array (incl. 0-d arrays): must contain exactly one element
+    if isinstance(value, np.ndarray):
+        if value.size != 1:
+            raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name))
+        return float(value.ravel()[0])
+
+    # List/tuple with a single element
+    if isinstance(value, (list, tuple)):
+        if len(value) != 1:
+            raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name))
+        return float(value[0])
+
+    # Plain int/float (but not bool)
+    if isinstance(value, (int, float)):
+        return float(value)
+
+    # Fallback: try to use .item() if present
+    item = getattr(value, "item", None)
+    if callable(item):
+        try:
+            return float(item())
+        except Exception:
+            pass
+
+    raise TypeError(f"Cannot convert {type(value)!r} to float")
+
+
+def format_array_like(value: Any, attr: attrs.Attribute) -> np.ndarray:
+    """Convert ``value`` to a :obj:`~numpy.ndarray`.
+
+    This function is intended for use as an attrs-style converter.
+
+    Parameters
+    ----------
+    value : :obj:`Any`
+        The value to format.
+    attr : :obj:`~attrs.Attribute`
+        The attribute being initialized; ``attr.name`` is used in the error message.
+
+    Returns
+    -------
+    formatted : :obj:`~numpy.ndarray`
+        The formatted value.
+
+    Raises
+    ------
+    :exc:`TypeError`
+        If the input cannot be converted to a float :obj:`~numpy.ndarray`.
+    :exc:`ValueError`
+        If the value is ``None``.
+    """
+
+    if value is None:
+        raise ValueError(ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))
+
+    try:
+        formatted = np.asarray(value, dtype=float)
+    except (TypeError, ValueError) as exc:
+        # Conversion failed (e.g. nested ragged objects, non-numeric values)
+        raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) from exc
+
+    return formatted
+
+
+def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None:
+    """Strict validator to ensure an attribute is a 1D NumPy array.
+
+    Enforces that ``value`` has exactly one dimension (``value.ndim == 1``).
+
+    This function is intended for use as an attrs-style validator.
+
+    Parameters
+    ----------
+    inst : :obj:`~nifreeze.data.pet.PET`
+        The instance being validated (unused; present for the validator signature).
+    attr : :obj:`~attrs.Attribute`
+        The attribute being validated; ``attr.name`` is used in the error message.
+    value : :obj:`Any`
+        The value to validate.
+
+    Raises
+    ------
+    :exc:`ValueError`
+        If the value is not 1D.
+    """
+
+    if not _has_ndim(value, 1):
+        raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name))
 
 
 @attrs.define(slots=True)
 class PET(BaseDataset[np.ndarray]):
     """Data representation structure for PET data."""
 
-    midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
+    midframe: np.ndarray = attrs.field(
+        default=None,
+        repr=_data_repr,
+        eq=attrs.cmp_using(eq=_cmp),
+        converter=attrs.Converter(format_array_like, takes_field=True),  # type: ignore
+        validator=validate_1d_array,
+    )
     """A (N,) numpy array specifying the midpoint timing of each sample or frame."""
-    total_duration: float = attrs.field(default=None, repr=True)
+    total_duration: float = attrs.field(
+        default=None,
+        repr=True,
+        converter=attrs.Converter(format_scalar_like, takes_field=True),  # type: ignore
+        validator=attrs.validators.optional(attrs.validators.instance_of(float)),
+    )
     """A float representing the total duration of the dataset."""
-    uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
-    """A (N,) numpy array specifying the uptake value of each sample or frame."""
+
+    def __attrs_post_init__(self) -> None:
+        """Enforce presence and basic consistency of PET data fields at
+        instantiation time.
+
+        Specifically, the length of the ``midframe`` attribute must match the
+        last dimension of the data (i.e., the number of frames), and
+        ``total_duration`` must be greater than the last ``midframe`` value.
+        """
+
+        def _check_attr_vol_length_match(
+            _attr_name: str, _value: np.ndarray | None, _n_frames: int
+        ) -> None:
+            if _value is not None and len(_value) != _n_frames:
+                raise ValueError(
+                    ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG.format(
+                        attribute=_attr_name,
+                        n_frames=_n_frames,
+                        attr_len=len(_value),
+                    )
+                )
+
+        n_frames = int(self.dataobj.shape[-1])
+        _check_attr_vol_length_match("midframe", self.midframe, n_frames)
+
+        # Ensure that the total duration is larger than the last midframe value
+        if self.total_duration <= self.midframe[-1]:
+            raise ValueError(
+                TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG.format(
+                    total_duration=self.total_duration,
+                    last_midframe=self.midframe[-1],
+                )
+            )
 
     def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
         return (self.midframe[idx],)
@@ -183,46 +391,38 @@ def from_filename(cls, filename: Path | str) -> Self:
         data = {k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")}
         return cls(**data)
 
-    @classmethod
-    def load(
-        cls, filename: Path | str, json_file: Path | str, brainmask_file: Path | str | None = None
-    ) -> Self:
-        """Load PET data."""
-        filename = Path(filename)
-        if filename.name.endswith(".h5"):
-            return cls.from_filename(filename)
-
-        img = load_api(filename, SpatialImage)
-        retval = cls(
-            dataobj=img.get_fdata(dtype="float32"),
-            affine=img.affine,
-        )
-
-        # Load metadata
-        with open(json_file, "r") as f:
-            metadata = json.load(f)
-
-        frame_duration = np.array(metadata["FrameDuration"])
-        frame_times_start = np.array(metadata["FrameTimesStart"])
-        midframe = frame_times_start + frame_duration / 2
-
-        retval.midframe = midframe
-        retval.total_duration = float(frame_times_start[-1] + frame_duration[-1])
+    def to_nifti(
+        self,
+        filename: Path | str | None = None,
+        write_hmxfms: bool = False,
+        order: int = 3,
+    ) -> nb.nifti1.Nifti1Image:
+        """
+        Export the PET object to disk (NIfTI, temporal attribute files).
 
-        assert len(retval.midframe) == retval.dataobj.shape[-1]
+        Parameters
+        ----------
+        filename : :obj:`os.pathlike`, optional
+            The output NIfTI file path.
+        write_hmxfms : :obj:`bool`, optional
+            If ``True``, the head motion affines will be written out to the
+            filesystem in BIDS' X5 format.
+        order : :obj:`int`, optional
+            The interpolation order to use when resampling the data.
 
-        if brainmask_file:
-            mask = load_api(brainmask_file, SpatialImage)
-            retval.brainmask = np.asanyarray(mask.dataobj)
+        Returns
+        -------
+        :obj:`~nibabel.nifti1.Nifti1Image`
+            NIfTI image written to disk.
+        """
 
-        return retval
+        return to_nifti(self, filename=filename, write_hmxfms=write_hmxfms, order=order)
 
 
 def from_nii(
     filename: Path | str,
-    frame_time: np.ndarray | list[float],
+    temporal_file: Path | str,
     brainmask_file: Path | str | None = None,
-    frame_duration: np.ndarray | list[float] | None = None,
 ) -> PET:
     """
     Load PET data from NIfTI, creating a PET object with appropriate metadata.
 
@@ -231,15 +431,12 @@
     ----------
     filename : :obj:`os.pathlike`
         The NIfTI file.
-    frame_time : :obj:`~numpy.ndarray` or :obj:`list` of :obj:`float`
-        The start times of each frame relative to the beginning of the acquisition.
+    temporal_file : :obj:`os.pathlike`
+        A JSON file containing temporal data. It must at least contain a
+        ``FrameTimesStart`` entry (the BIDS PET frame start times).
     brainmask_file : :obj:`os.pathlike`, optional
        A brainmask NIfTI file. If provided, will be loaded and stored
        in the returned dataset.
-    frame_duration : :obj:`~numpy.ndarray` or :obj:`list` of :obj:`float`, optional
-        The duration of each frame.
-        If :obj:`None`, it is derived by the difference of consecutive frame times,
-        defaulting the last frame to match the second-last.
 
     Returns
     -------
@@ -254,62 +451,103 @@
     """
     filename = Path(filename)
-    # Load from NIfTI
+
+    # 1) Load the NIfTI image
     img = load_api(filename, SpatialImage)
-    data = img.get_fdata(dtype=np.float32)
-    pet_obj = PET(
-        dataobj=data,
+    fulldata = get_data(img)
+
+    # 2) Load the temporal data
+    with open(temporal_file, "r") as f:
+        temporal_attrs = json.load(f)
+
+    frame_time = temporal_attrs.get(FRAME_TIME_START_KEY, None)
+    if frame_time is None:
+        raise RuntimeError(TEMPORAL_FILE_KEY_ERROR_MSG.format(key=FRAME_TIME_START_KEY))
+
+    # 3) If a brainmask_file was provided, load it
+    brainmask_data = None
+    if brainmask_file is not None:
+        mask_img = load_api(brainmask_file, SpatialImage)
+        brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)
+
+    # 4) Compute the temporal attributes
+    midframe, total_duration = _compute_temporal_markers(np.asarray(frame_time))
+
+    # 5) Create and return the PET instance
+    return PET(
+        dataobj=fulldata,
         affine=img.affine,
+        brainmask=brainmask_data,
+        midframe=midframe,
+        total_duration=total_duration,
    )
-    pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum)
 
-    # Convert to a float32 numpy array and zero out the earliest time
-    frame_time_arr = np.array(frame_time, dtype=np.float32)
-    frame_time_arr -= frame_time_arr[0]
-    pet_obj.midframe = frame_time_arr
+
+def _compute_temporal_markers(frame_time: np.ndarray) -> Tuple[np.ndarray, float]:
+    """Compute the frame temporal markers from the frame time values.
 
-    # If the user doesn't provide frame_duration, we derive it:
-    if frame_duration is None:
-        durations = _compute_frame_duration(pet_obj.midframe)
-    else:
-        durations = np.array(frame_duration, dtype=np.float32)
+    Computes the midframe times and the total duration following the principles
+    detailed below.
 
-    # Set total_duration and shift frame_time to the midpoint
-    pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1])
-    pet_obj.midframe = frame_time_arr + 0.5 * durations
+    Let :math:`K` be the number of frames and :math:`t_{k}` be the :math:`k`-th
+    (start) frame time. For each frame :math:`k`, the frame duration
+    :math:`d_{k}` is defined as the difference between consecutive frame times:
 
-    # If a brain mask is provided, load and attach
-    if brainmask_file is not None:
-        mask_img = load_api(brainmask_file, SpatialImage)
-        pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
+    .. math::
+        d_{k} = t_{k+1} - t_{k}
 
-    return pet_obj
+    Since this yields one value fewer than the number of frames, the last frame
+    duration is set equal to the second-to-last one in this implementation.
 
+    Per-frame midpoints :math:`m_{k}` are computed as:
 
-def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
-    """Compute the frame duration from the midframe values.
+    .. math::
+        m_{k} = t_{k} + \\frac{d_k}{2}
+
+    The total duration :math:`D` of the acquisition is a scalar computed as the
+    sum of the frame durations:
+
+    .. math::
+        D = \\sum_{k=1}^{K} d_{k}
+
+    or, equivalently, the sum of the last frame's start time and its duration
+    once the frame times have been time-origin shifted:
+
+    .. math::
+        D = t_{K} + d_{K}
+
+    For simplicity, frame times are time-origin shifted (i.e., the earliest
+    time is zeroed out) at the beginning of the process if this has not
+    already been done.
 
     Parameters
     ----------
-    midframe : :obj:`~numpy.ndarray`
-        Midframe time values.
+    frame_time : :obj:`~numpy.ndarray`
+        Frame time values.
 
     Returns
     -------
-    durations : :obj:`~numpy.ndarray`
-        Frame duration.
+    :obj:`tuple`
+        Midpoint timing of each frame and total duration.
     """
+    # Time-origin shift: zero out the earliest time if necessary (without
+    # mutating the caller's array in place)
+    if not np.isclose(frame_time.ravel()[0], 0):
+        frame_time = frame_time - frame_time.ravel()[0]
+
+    # Frame durations from consecutive start times
-    durations = np.diff(midframe)
-    if len(durations) == (len(midframe) - 1):
-        durations = np.append(durations, durations[-1])  # last frame same as second-last
+    frame_duration = np.diff(frame_time)
+    if len(frame_duration) == (len(frame_time) - 1):
+        frame_duration = np.append(frame_duration, frame_duration[-1])  # last same as second-to-last
+
+    midframe = frame_time + frame_duration / 2
+    total_duration = float(frame_time[-1] + frame_duration[-1])
 
-    return durations
+    return midframe, total_duration
 
 
-def _compute_uptake_statistic(data: np.ndarray, stat_func: Callable = np.sum):
+def compute_uptake_statistic(data: np.ndarray, stat_func: Callable[..., np.ndarray] = np.sum):
     """Compute a statistic over all voxels for each frame on a PET sequence.
 
     Assumes the last dimension corresponds to the number of frames in the
@@ -330,3 +568,50 @@
     """
 
     return stat_func(data.reshape(-1, data.shape[-1]), axis=0)
+
+
+def reconstruct_frame_time(midframe: np.ndarray, total_duration: float) -> np.ndarray:
+    """
+    Reconstruct the frame start times from midframe and total_duration,
+    assuming:
+    - mid[i] = t[i] + d[i]/2
+    - d[-1] == d[-2] (the last duration duplicates the previous one)
+    - the returned frame_time uses the t[0] == 0 convention
+
+    The computation is fully vectorized (no loops) and performs no input
+    verification beyond a consistency check against ``total_duration``.
+    """
+    mid = np.asarray(midframe, dtype=float).ravel()
+    n_frames = mid.size
+    if n_frames <= 1:
+        return np.zeros(n_frames, dtype=float)
+
+    # Adjacent midpoint differences: np.diff(mid)[i] = (d[i] + d[i + 1]) / 2
+    s = 2.0 * np.diff(mid)  # s[i] = d[i] + d[i + 1], shape (N - 1,)
+
+    # Solve the recurrence d[i + 1] = s[i] - d[i], with d[0] = 2 * mid[0]
+    # (t[0] == 0): multiplying through by (-1) ** (i + 1) turns it into a
+    # cumulative sum.
+    signs = (-1.0) ** np.arange(1, n_frames)
+    d = np.empty(n_frames, dtype=float)
+    d[0] = 2.0 * mid[0]
+    d[1:] = signs * (d[0] + np.cumsum(signs * s))
+    d[-1] = d[-2]  # enforce the duplicated-last-duration convention
+
+    # Consistency check against total_duration (not needed to compute)
+    assert np.isclose(mid[-1] + d[-1] / 2.0, float(total_duration))
+
+    starts = np.cumsum(np.concatenate(([0.0], d[:-1])))
+    return starts
+
+
+# Worked example (see also test_compute_temporal_markers):
+#     t = np.array([0.0, 1.0, 2.5, 4.0])
+#     d = np.append(np.diff(t), np.diff(t)[-1])  # [1.0, 1.5, 1.5, 1.5]
+#     mid = t + d / 2
+#     total = float(t[-1] + d[-1])
+#     recovered = reconstruct_frame_time(mid, total)
#     np.testing.assert_allclose(recovered, [0.0, 1.0, 2.5, 4.0
]) diff --git a/test/conftest.py b/test/conftest.py index 4b9a091ce..7c79894f7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -323,10 +323,17 @@ def setup_random_pet_data(request): n_frames = 5 vol_size = (4, 4, 4) - midframe = np.arange(n_frames, dtype=np.float32) + 1 - total_duration = float(n_frames + 1) + frame_time = np.arange(n_frames, dtype=np.float32) + 1 if marker: - n_frames, vol_size, midframe, total_duration = marker.args + n_frames, vol_size, frame_time = marker.args + + frame_time = np.asarray(frame_time) + frame_time -= frame_time[0] + frame_duration = np.diff(frame_time) + if len(frame_duration) == (len(frame_time) - 1): + frame_duration = np.append(frame_duration, frame_duration[-1]) + midframe = frame_time + frame_duration / 2 + total_duration = float(frame_time[-1] + frame_duration[-1]) rng = request.node.rng @@ -339,6 +346,7 @@ def setup_random_pet_data(request): pet_dataobj, affine, brainmask_dataobj, + frame_time, midframe, total_duration, ) diff --git a/test/test_data.py b/test/test_data.py index 15a7ecb1d..bd69c756c 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -21,6 +21,7 @@ # https://www.nipreps.org/community/licensing/ # +import json import os from typing import Optional @@ -279,6 +280,11 @@ def test_load_pet_from_nii(monkeypatch, tmp_path): nb.save(img, fname) nb.save(mask_img, mask) + temporal_fname = tmp_path / "temporal.json" + temporal_data = {"frame_time": np.ones(4).tolist()} + with temporal_fname.open("w", encoding="utf-8") as f: + json.dump(temporal_data, f, ensure_ascii=False, indent=2, sort_keys=True) + called = {} sentinel = object() @@ -290,9 +296,9 @@ def dummy_from_nii(filename, brainmask_file=None, **kwargs): monkeypatch.setattr(pet, "from_nii", dummy_from_nii) - retval = data.load(fname, brainmask_file=mask, frame_time=np.zeros((4,))) + retval = data.load(fname, brainmask_file=mask, temporal_file=temporal_fname) assert retval is sentinel assert called["filename"] == fname assert called["brainmask_file"] == mask - assert "frame_time" in called["kwargs"] + assert "temporal_file" in called["kwargs"] diff --git a/test/test_data_pet.py b/test/test_data_pet.py index 1e40cf8a9..b336c07cc 100644 --- a/test/test_data_pet.py +++ b/test/test_data_pet.py @@ -22,36 +22,57 @@ # import json +import math from pathlib import Path +from typing import Any, Type +import attrs import nibabel as nb import numpy as np import pytest from nitransforms.linear import Affine -from nifreeze.data.pet import PET, _compute_frame_duration, _compute_uptake_statistic, from_nii +from nifreeze.data import load as nifreeze_load +from nifreeze.data.pet import ( + ARRAY_ATTRIBUTE_NDIM_ERROR_MSG, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ATTRIBUTE_ABSENCE_ERROR_MSG, + ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG, + FRAME_TIME_START_KEY, + PET, + SCALAR_ATTRIBUTE_ERROR_MSG, + TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG, + TEMPORAL_FILE_KEY_ERROR_MSG, + _compute_temporal_markers, + compute_uptake_statistic, + format_array_like, + format_scalar_like, + from_nii, + validate_1d_array, +) from nifreeze.utils.ndimage import load_api -@pytest.fixture -def random_dataset(setup_random_pet_data) -> PET: - """Create a PET dataset with random data for testing.""" - - ( - pet_dataobj, - affine, - brainmask_dataobj, - midframe, - total_duration, - ) = setup_random_pet_data - - return PET( - dataobj=pet_dataobj, - affine=affine, - brainmask=brainmask_dataobj, - midframe=midframe, - total_duration=total_duration, - ) +def _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj): 
+ pet = nb.Nifti1Image(pet_dataobj, affine) + brainmask = nb.Nifti1Image(brainmask_dataobj, affine) + + return pet, brainmask + + +def _serialize_pet_data(pet, brainmask, frame_time, _tmp_path): + pet_fname = _tmp_path / "pet.nii.gz" + brainmask_fname = _tmp_path / "brainmask.nii.gz" + temporal_fname = _tmp_path / "temporal.json" + + nb.save(pet, pet_fname) + nb.save(brainmask, brainmask_fname) + + temporal_data = {FRAME_TIME_START_KEY: frame_time.tolist()} + with temporal_fname.open("w", encoding="utf-8") as f: + json.dump(temporal_data, f, ensure_ascii=False, indent=2, sort_keys=True) + + return pet_fname, brainmask_fname, temporal_fname @pytest.fixture @@ -64,17 +85,284 @@ def random_nifti_file(tmp_path, setup_random_uniform_spatial_data) -> Path: @pytest.mark.parametrize( - "midframe, expected", + "attr_name, value, expected_exc, expected_msg", [ - ([1.0, 4.0], [3.0, 3.0]), - ([0.0, 5.0, 9.0, 12.0], [5.0, 4.0, 3.0, 3.0]), + ("any_name", None, ValueError, ATTRIBUTE_ABSENCE_ERROR_MSG), + ( + "any_name", + [[10.0], [20.0, 30.0], [40.0], [50.0]], + TypeError, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ), # Ragged + ( + "any_name", + np.array([[-0.9], [0.06, 0.12], [0.27], [0.08]], dtype=object), + TypeError, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ), # Ragged ], ) -def test_compute_frame_duration(midframe, expected): - midframe = np.array(midframe) - expected = np.array(expected) - durations = _compute_frame_duration(midframe) - np.testing.assert_allclose(durations, expected) +def test_format_array_like_errors(attr_name, value, expected_exc, expected_msg): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + with pytest.raises(expected_exc, match=expected_msg.format(attribute=attr_name)): + format_array_like(value, dummy_attr) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "attr_name, value, expected_exc, expected_msg", + [ + ("any_name", None, ValueError, ATTRIBUTE_ABSENCE_ERROR_MSG), + ("any_name", True, ValueError, SCALAR_ATTRIBUTE_ERROR_MSG), + ("any_name", (2, 2), ValueError, SCALAR_ATTRIBUTE_ERROR_MSG), + ("any_name", np.asarray([1.0, 2.0]), ValueError, SCALAR_ATTRIBUTE_ERROR_MSG), + ], +) +def test_format_scalar_errors(attr_name, value, expected_exc, expected_msg): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + + with pytest.raises(expected_exc, match=expected_msg.format(attribute=attr_name)): + format_scalar_like(value, dummy_attr) # type: ignore[arg-type] + + +@pytest.mark.parametrize("value", [[1.0, 2.0, 3.0, 4.0], (1.0, 2.0, 3.0, 4.0)]) +@pytest.mark.parametrize("attr_name", ("midframe",)) +def test_format_array_like(value, attr_name): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + dummy_attr = dummy_attr_cls.__attrs_attrs__[0] + + obtained = format_array_like(value, dummy_attr) + assert isinstance(obtained, np.ndarray) + assert obtained.shape == np.asarray(value).shape + assert np.allclose(obtained, np.asarray(value)) + + +@pytest.mark.parametrize("value", [1.0, [2.0], (3.0,), np.array(4.0), np.array([5.0])]) +def test_format_scalar_like(value): + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {"total_duration": attrs.field()}) + dummy_attr = 
dummy_attr_cls.__attrs_attrs__[0] + + obtained = format_scalar_like(value, dummy_attr) + assert isinstance(obtained, float) + assert np.allclose(obtained, np.asarray(value)) + + +@pytest.mark.parametrize("attr_name, value", [("my_attr", np.asarray([1.0, 2.0, 3.0, 4.0]))]) +@pytest.mark.parametrize("extra_dimensions", (1, 2)) +@pytest.mark.parametrize("transpose", (True, False)) +def test_validate_1d_arr_errors( + request, monkeypatch, attr_name, value, extra_dimensions, transpose +): + def _add_extra_dim(_rng, _attr_name, _extra_dimensions, _transpose, _value): + _arr = np.concatenate( + [ + _value[:, None], + rng.random((_value.size, _extra_dimensions)), + ], + axis=1, + ) + _arr = _arr.T if _transpose else _arr + return _arr + + rng = request.node.rng + _value = _add_extra_dim(rng, attr_name, extra_dimensions, transpose, value) + + monkeypatch.setattr(PET, "__init__", lambda self, *a, **k: None) + + # Produce a valid attrs.Attribute for the test + dummy_attr_cls: Type[Any] = attrs.make_class("Dummy", {attr_name: attrs.field()}) + new_attr = dummy_attr_cls.__attrs_attrs__[0] + + # Replace PET's attribute metadata with just the new_attr. + # attrs.fields() reads PET.__attrs_attrs__ at runtime, so setting this + # effectively "removes" previous attributes and leaves only our single one. + monkeypatch.setattr(PET, "__attrs_attrs__", (new_attr,), raising=False) + # Also set a matching annotation dict + monkeypatch.setattr(PET, "__annotations__", {attr_name: object()}, raising=False) + + # Instantiate and obtain the attrs.Attribute from PET + inst = PET() + dummy_attr = attrs.fields(PET)[0] + # assert isinstance(dummy_attr, attrs.Attribute) + # assert dummy_attr.name == attr_name + + with pytest.raises( + ValueError, + match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr_name), + ): + validate_1d_array(inst, dummy_attr, _value) + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +@pytest.mark.parametrize("attr_name", ("midframe",)) +@pytest.mark.parametrize("extra_dimensions", (1, 2)) +@pytest.mark.parametrize("transpose", (True, False)) +def test_pet_instantiation_attribute_validate_1d_arr_errors( + request, setup_random_pet_data, attr_name, extra_dimensions, transpose +): + def _add_extra_dim(_rng, _attr_name, _extra_dimensions, _transpose, **_kwargs): + _arr = np.concatenate( + [ + _kwargs[_attr_name][:, None], + rng.random((_kwargs[_attr_name].size, _extra_dimensions)), + ], + axis=1, + ) + _kwargs[_attr_name] = _arr.T if _transpose else _arr + return _kwargs + + rng = request.node.rng + pet_dataobj, affine, _, _, midframe, total_duration = setup_random_pet_data + + attrs_dict = dict( + midframe=midframe, + total_duration=total_duration, + ) + _attrs_dict = _add_extra_dim(rng, attr_name, extra_dimensions, transpose, **attrs_dict) + + with pytest.raises( + ValueError, + match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr_name), + ): + PET(dataobj=pet_dataobj, affine=affine, **_attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("attr_name", ("midframe", "total_duration")) +def test_pet_instantiation_attribute_convert_absence_errors( + setup_random_uniform_spatial_data, + attr_name, +): + data, affine = setup_random_uniform_spatial_data + + n_frames = data.shape[-1] + # Create a dict with default valid attribute values + attrs_dict: dict[str, np.ndarray | float | None] = dict( + midframe=np.ones(n_frames, dtype=np.float32), + total_duration=1.0, + ) + + # Override only 
the attribute under test + attrs_dict[attr_name] = None + + with pytest.raises(ValueError, match=ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr_name)): + PET(dataobj=data, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("attr_name", ("midframe",)) +@pytest.mark.parametrize( + "value", + [ + ([[10.0], [20.0, 30.0], [40.0], [50.0]]), # Ragged + (np.array([[-0.9], [0.06, 0.12], [0.27], [0.08]], dtype=object)), # Ragged + ], +) +def test_pet_instantiation_attribute_convert_object_errors( + setup_random_uniform_spatial_data, attr_name, value +): + data, affine = setup_random_uniform_spatial_data + + n_frames = data.shape[-1] + # Create a dict with some valid attributes + attrs_dict = dict( + midframe=np.ones(n_frames, dtype=np.float32), + total_duration=n_frames + 1, + ) + + # Override only the attribute under test + attrs_dict[attr_name] = value + + with pytest.raises( + TypeError, match=ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr_name) + ): + PET(dataobj=data, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +@pytest.mark.parametrize("attr_name", ("midframe",)) +@pytest.mark.parametrize( + ("extra_volume_count", "extra_attribute_count"), + [(1, 0), (2, 0), (2, 1), (0, 1), (0, 2), (1, 2)], +) +def test_pet_instantiation_attribute_vol_mismatch_error( + setup_random_pet_data, attr_name, extra_volume_count, extra_attribute_count +): + pet_dataobj, affine, _, _, midframe, total_duration = setup_random_pet_data + + n_frames = int(pet_dataobj.shape[-1]) + attrs_dict = dict( + midframe=midframe, + total_duration=total_duration, + ) + + # Add extra volumes: simply concatenate the last volume + if extra_volume_count: + extra_dwi_dataobj = np.tile(pet_dataobj[..., -1:], (1, extra_volume_count)) + pet_dataobj = np.concatenate((pet_dataobj, extra_dwi_dataobj), axis=-1) + n_frames = int(pet_dataobj.shape[-1]) + # Add extra values to attribute: simply concatenate the attribute + if extra_attribute_count: + base = attrs_dict[attr_name] + extra_vals = np.repeat(base[-1], extra_attribute_count) + attrs_dict[attr_name] = np.concatenate((base, extra_vals)) + + attr_val = attrs_dict[attr_name] + + with pytest.raises( + ValueError, + match=ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR_MSG.format( + attribute=attr_name, n_frames=n_frames, attr_len=len(attr_val) + ), + ): + PET(dataobj=pet_dataobj, affine=affine, **attrs_dict) # type: ignore[arg-type] + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +@pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 2.0, 3.0])) +@pytest.mark.parametrize("attr_name, excess_value", + [("midframe", 1.0), ("midframe", 2.0), ("total_duration", 0.0), ("total_duration", -1.0), ("total_duration", -2.0)], +) +def test_pet_instantiation_attribute_inconsistency_error(setup_random_pet_data, attr_name, excess_value): + pet_dataobj, affine, _, _, midframe, total_duration = setup_random_pet_data + + if attr_name == "midframe": + midframe[-1] = total_duration + excess_value + elif attr_name == "total_duration": + total_duration = midframe[-1] + excess_value + + attrs_dict = dict( + midframe=midframe, + total_duration=total_duration, + ) + + with pytest.raises( + ValueError, + match=TEMPORAL_ATTRIBUTE_INCONSISTENCY_ERROR_MSG.format( + total_duration=total_duration, last_midframe=midframe[-1] + ), + ): + PET(dataobj=pet_dataobj, affine=affine, **attrs_dict) 
# type: ignore[arg-type] + + +@pytest.mark.parametrize( + "frame_time, expected_midframe, expected_total_duration", + [ + ([1.0, 4.0], [1.5, 4.5], 6.0), + ([0.0, 5.0, 9.0, 12.0], [2.5, 7.0, 10.5, 13.5], 15.0), + ], +) +def test_compute_temporal_markers(frame_time, expected_midframe, expected_total_duration): + frame_time = np.array(frame_time) + expected_midframe = np.array(expected_midframe) + midframe, total_duration = _compute_temporal_markers(frame_time) + np.testing.assert_allclose(midframe, expected_midframe) + assert np.isclose(total_duration, expected_total_duration) @pytest.mark.parametrize("stat_func", (np.sum, np.mean, np.std)) @@ -83,131 +371,243 @@ def test_compute_uptake_statistic(stat_func): data = rng.random((4, 4, 4, 5), dtype=np.float32) expected = stat_func(data.reshape(-1, data.shape[-1]), axis=0) - obtained = _compute_uptake_statistic(data, stat_func=stat_func) + obtained = compute_uptake_statistic(data, stat_func=stat_func) np.testing.assert_array_equal(obtained, expected) -@pytest.mark.parametrize( - ("brainmask_file", "frame_time", "frame_duration"), - [ - (None, [0.0, 5.0], [5.0, 5.0]), - (None, [10.0, 15.0], [5.0, 5.0]), - ("mask.nii.gz", [0.0, 5.0], [5.0, 5.0]), - ("mask.nii.gz", [0.0, 5.0], None), - ], -) -def test_from_nii(tmp_path, random_nifti_file, brainmask_file, frame_time, frame_duration): - filename = random_nifti_file - img = load_api(filename, nb.Nifti1Image) - if brainmask_file: - mask_data = np.ones(img.get_fdata().shape[:-1], dtype=bool) - mask_img = nb.Nifti1Image(mask_data.astype(np.uint8), img.affine) - mask_img.to_filename(brainmask_file) - - pet_obj = from_nii( - filename, - brainmask_file=brainmask_file, - frame_time=frame_time, - frame_duration=frame_duration, - ) - assert isinstance(pet_obj, PET) - assert pet_obj.dataobj.shape == img.get_fdata().shape - np.testing.assert_array_equal(pet_obj.affine, img.affine) +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_from_nii_errors(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + pet_fname = tmp_path / "pet.nii.gz" + brainmask_fname = tmp_path / "brainmask.nii.gz" + temporal_fname = tmp_path / "temporal.json" + + nb.save(pet, pet_fname) + nb.save(brainmask, brainmask_fname) + + # Check frame time + temporal_data = {"any_key": frame_time.tolist()} + with temporal_fname.open("w", encoding="utf-8") as f: + json.dump(temporal_data, f, ensure_ascii=False, indent=2, sort_keys=True) + + with pytest.raises( + RuntimeError, match=TEMPORAL_FILE_KEY_ERROR_MSG.format(key=FRAME_TIME_START_KEY) + ): + from_nii( + pet_fname, + temporal_fname, + brainmask_file=brainmask_fname, + ) + + +@pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 4.0, 6.0])) +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_from_nii(tmp_path, setup_random_pet_data): + from nifreeze.data.base import BaseDataset - # Convert to a float32 numpy array and zero out the earliest time - frame_time_arr = np.array(frame_time, dtype=np.float32) - frame_time_arr -= frame_time_arr[0] - if frame_duration is None: - durations = _compute_frame_duration(frame_time_arr) - else: - durations = np.array(frame_duration, dtype=np.float32) + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data - expected_total_duration = float(frame_time_arr[-1] + 
durations[-1]) - expected_midframe = frame_time_arr + 0.5 * durations + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) - np.testing.assert_allclose(pet_obj.midframe, expected_midframe) - assert pet_obj.total_duration == expected_total_duration + pet_fname, brainmask_fname, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) - if brainmask_file: - assert pet_obj.brainmask is not None - np.testing.assert_array_equal(pet_obj.brainmask, mask_data) + # Read back using public API + pet_obj_from_nii = from_nii(pet_fname, temporal_fname, brainmask_file=brainmask_fname) + assert isinstance(pet_obj_from_nii, PET) + + attrs_dict: dict[str, np.ndarray | float | None] = dict( + midframe=midframe, + total_duration=total_duration, + ) + + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + # No need to check base class attributes: remove them + base_attrs = [ + a.name + for a in attrs.fields(BaseDataset) + if not a.name.startswith("_") and not a.name.isdigit() + ] + attrs_to_check = [_ for _ in attrs_to_check if _ not in base_attrs] + + for attr_name in attrs_to_check: + val_direct = attrs_dict[attr_name] + val_from_nii = getattr(pet_obj_from_nii, attr_name) + + if val_direct is None or val_from_nii is None: + assert val_direct is None and val_from_nii is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_from_nii.shape + assert np.allclose(val_direct, val_from_nii), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_from_nii), f"{attr_name} values differ" + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_to_nifti(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_to_nifti(tmp_path, random_dataset): out_filename = tmp_path / "random_pet_out.nii.gz" - random_dataset.to_nifti(str(out_filename)) + pet_obj.to_nifti(str(out_filename)) assert out_filename.exists() loaded_img = load_api(str(out_filename), nb.Nifti1Image) - assert np.allclose(loaded_img.get_fdata(), random_dataset.dataobj) - assert np.allclose(loaded_img.affine, random_dataset.affine) + assert np.allclose(loaded_img.get_fdata(), pet_obj.dataobj) + assert np.allclose(loaded_img.affine, pet_obj.affine) units = loaded_img.header.get_xyzt_units() assert units[0] == "mm" -@pytest.mark.parametrize( - ("frame_time", "frame_duration"), - [ - ([0.0, 5.0], [5.0, 5.0]), - ], -) -def test_round_trip(tmp_path, random_nifti_file, frame_time, frame_duration): - filename = random_nifti_file - img = load_api(filename, nb.Nifti1Image) - pet_obj = from_nii(filename, frame_time=frame_time, frame_duration=frame_duration) +@pytest.mark.random_pet_data(2, (2, 2, 2), np.asarray([0.0, 5.0])) +def test_round_trip(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + pet_fname, _, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) + + img = 
load_api(pet_fname, nb.Nifti1Image) + pet_obj = from_nii(pet_fname, temporal_fname) out_fname = tmp_path / "random_pet_out.nii.gz" pet_obj.to_nifti(out_fname) assert out_fname.exists() loaded_img = load_api(out_fname, nb.Nifti1Image) - np.testing.assert_array_equal(loaded_img.affine, img.affine) + assert np.allclose(loaded_img.affine, img.affine) np.testing.assert_allclose(loaded_img.get_fdata(), img.get_fdata()) units = loaded_img.header.get_xyzt_units() assert units[0] == "mm" -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_pet_set_transform_updates_motion_affines(random_dataset): +def test_equality_operator(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + pet_fname, brainmask_fname, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) + + # Read back using public API + pet_obj_from_nii = from_nii(pet_fname, temporal_fname, brainmask_file=brainmask_fname) + + # Direct instantiation with the same arrays + pet_obj_direct = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) + + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + + # Sanity checks (element-wise) + for attr_name in attrs_to_check: + val_direct = getattr(pet_obj_direct, attr_name) + val_from_nii = getattr(pet_obj_from_nii, attr_name) + + if val_direct is None or val_from_nii is None: + assert val_direct is None and val_from_nii is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_from_nii.shape + assert np.allclose(val_direct, val_from_nii), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_from_nii), f"{attr_name} values differ" + + # Test equality operator + assert pet_obj_direct == pet_obj_from_nii + + # Test equality operator against an instance from HDF5 + hdf5_filename = tmp_path / "test_pet.h5" + pet_obj_from_nii.to_filename(hdf5_filename) + + round_trip_pet_obj = PET.from_filename(hdf5_filename) + + # Symmetric equality + assert pet_obj_from_nii == round_trip_pet_obj + assert round_trip_pet_obj == pet_obj_from_nii + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_pet_set_transform_updates_motion_affines(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) + idx = 2 - data_before = np.copy(random_dataset.dataobj[..., idx]) + data_before = np.copy(pet_obj.dataobj[..., idx]) affine = np.eye(4) - random_dataset.set_transform(idx, affine) + pet_obj.set_transform(idx, affine) - np.testing.assert_allclose(random_dataset.dataobj[..., idx], data_before) - assert random_dataset.motion_affines is not None - assert len(random_dataset.motion_affines) == len(random_dataset) - assert isinstance(random_dataset.motion_affines[idx], Affine) - np.testing.assert_array_equal(random_dataset.motion_affines[idx].matrix, affine) + np.testing.assert_allclose(pet_obj.dataobj[..., idx], data_before) + assert pet_obj.motion_affines is not None + assert 
len(pet_obj.motion_affines) == len(pet_obj) + assert isinstance(pet_obj.motion_affines[idx], Affine) + assert np.allclose(pet_obj.motion_affines[idx].matrix, affine) - vol, aff, time = random_dataset[idx] - assert aff is random_dataset.motion_affines[idx] + vol, aff, time = pet_obj[idx] + assert aff is pet_obj.motion_affines[idx] -@pytest.mark.random_uniform_spatial_data((2, 2, 2, 2), 0.0, 1.0) -def test_pet_load(request, tmp_path, setup_random_uniform_spatial_data): - data, affine = setup_random_uniform_spatial_data - img = nb.Nifti1Image(data, affine) - fname = tmp_path / "pet.nii.gz" - img.to_filename(fname) +@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0])) +def test_pet_load(tmp_path, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, frame_time, midframe, total_duration = setup_random_pet_data - brainmask_dataobj = request.node.rng.choice([True, False], size=data.shape[:3]).astype( - np.uint8 + pet, brainmask = _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj.astype(np.uint8)) + + # Direct instantiation with the same arrays + pet_obj_direct = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, ) - brainmask = nb.Nifti1Image(brainmask_dataobj, affine) - brainmask_fname = tmp_path / "brainmask.nii.gz" - brainmask.to_filename(brainmask_fname) - - json_file = tmp_path / "pet.json" - metadata = { - "FrameDuration": [1.0, 1.0], - "FrameTimesStart": [0.0, 1.0], - } - json_file.write_text(json.dumps(metadata)) - - pet_obj = PET.load(fname, json_file, brainmask_fname) - - assert pet_obj.dataobj.shape == data.shape - assert np.allclose(pet_obj.midframe, [0.5, 1.5]) - assert pet_obj.total_duration == 2.0 - if pet_obj.brainmask is not None: - assert pet_obj.brainmask.shape == brainmask_dataobj.shape + + pet_fname, brainmask_fname, temporal_fname = _serialize_pet_data(pet, brainmask, frame_time, tmp_path) + + pet_from_nii_kwargs = {"temporal_file": temporal_fname} + + pet_obj_load = nifreeze_load(pet_fname, brainmask_fname, **pet_from_nii_kwargs) + + # Get all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + + # Sanity checks (element-wise) + for attr_name in attrs_to_check: + val_direct = getattr(pet_obj_direct, attr_name) + val_load = getattr(pet_obj_load, attr_name) + + if val_direct is None or val_load is None: + assert val_direct is None and val_load is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_load.shape + assert np.allclose(val_direct, val_load), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_load), f"{attr_name} values differ" diff --git a/test/test_estimator.py b/test/test_estimator.py index f1547bd0a..cc4df583f 100644 --- a/test/test_estimator.py +++ b/test/test_estimator.py @@ -78,13 +78,25 @@ def __getitem__(self, idx): class DummyPETDataset(BaseDataset): - def __init__(self, pet_dataobj, affine, brainmask_dataobj, midframe, total_duration): + def __init__( + self, + pet_dataobj, + affine, + brainmask_dataobj, + frame_time, + uptake, + frame_duration, + midframe, + total_duration, + ): self.dataobj = pet_dataobj self.affine = affine self.brainmask = brainmask_dataobj + self.frame_time = frame_time + self.uptake = uptake + self.frame_duration = frame_duration self.midframe = midframe self.total_duration = total_duration - self.uptake = 
np.sum(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0)
 
     def __len__(self):
         return self.dataobj.shape[-1]
@@ -155,11 +167,23 @@ def test_estimator_iterator_index_match(
             pet_dataobj,
             affine,
             brainmask_dataobj,
+            frame_time,
             midframe,
             total_duration,
         ) = setup_random_pet_data
-        dataset = DummyPETDataset(pet_dataobj, affine, brainmask_dataobj, midframe, total_duration)
+        # The fixture returns six values; derive the remaining DummyPETDataset
+        # inputs from them.
+        uptake = np.sum(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0)
+        frame_duration = 2 * (midframe - frame_time)
+        dataset = DummyPETDataset(
+            pet_dataobj,
+            affine,
+            brainmask_dataobj,
+            frame_time,
+            uptake,
+            frame_duration,
+            midframe,
+            total_duration,
+        )
         uptake = dataset.uptake
         kwargs = {"uptake": uptake}
     else:
diff --git a/test/test_integration_pet.py b/test/test_integration_pet.py
index 68b2a318e..207de2831 100644
--- a/test/test_integration_pet.py
+++ b/test/test_integration_pet.py
@@ -21,8 +21,10 @@
 #   https://www.nipreps.org/community/licensing/
 #
 
+import math
 import types
 
+import attrs
 import numpy as np
 import pytest
 
@@ -30,19 +32,11 @@
 from nifreeze.estimator import PETMotionEstimator
 
 
-@pytest.fixture
-def random_dataset(setup_random_pet_data) -> PET:
-    """Create a PET dataset with random data for testing."""
+@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 13.0, 17.0, 30.0, 33.0]))
+def test_lofo_split_shapes(tmp_path, setup_random_pet_data):
+    pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data
 
-    (
-        pet_dataobj,
-        affine,
-        brainmask_dataobj,
-        midframe,
-        total_duration,
-    ) = setup_random_pet_data
-
-    return PET(
+    pet_obj = PET(
         dataobj=pet_dataobj,
         affine=affine,
         brainmask=brainmask_dataobj,
@@ -50,31 +44,63 @@
         total_duration=total_duration,
     )
 
-
-@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0]), 5.0)
-def test_lofo_split_shapes(random_dataset, tmp_path):
     idx = 2
-    (train_data, train_times), (test_data, test_time) = random_dataset.lofo_split(idx)
-    assert train_data.shape[-1] == random_dataset.dataobj.shape[-1] - 1
-    np.testing.assert_array_equal(test_data, random_dataset.dataobj[..., idx])
-    np.testing.assert_array_equal(train_times, np.delete(random_dataset.midframe, idx))
-    assert test_time == random_dataset.midframe[idx]
+    (train_data, train_times), (test_data, test_time) = pet_obj.lofo_split(idx)
+    assert train_data.shape[-1] == pet_obj.dataobj.shape[-1] - 1
+    np.testing.assert_array_equal(test_data, pet_obj.dataobj[..., idx])
+    np.testing.assert_array_equal(train_times, np.delete(pet_obj.midframe, idx))
+    assert test_time == pet_obj.midframe[idx]
+
+
+@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]))
+def test_to_from_filename_roundtrip(tmp_path, setup_random_pet_data):
+    pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data
+
+    pet_obj = PET(
+        dataobj=pet_dataobj,
+        affine=affine,
+        brainmask=brainmask_dataobj,
+        midframe=midframe,
+        total_duration=total_duration,
+    )
 
-@pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 2.0, 3.0]), 4.0)
-def test_to_from_filename_roundtrip(random_dataset, tmp_path):
     out_file = tmp_path / "petdata"
-    random_dataset.to_filename(out_file)
+    pet_obj.to_filename(out_file)
     assert (tmp_path / "petdata.h5").exists()
 
     loaded = PET.from_filename(tmp_path / "petdata.h5")
-    np.testing.assert_allclose(loaded.dataobj, random_dataset.dataobj)
-    np.testing.assert_allclose(loaded.affine, random_dataset.affine)
-    np.testing.assert_allclose(loaded.midframe, random_dataset.midframe)
-    assert loaded.total_duration == random_dataset.total_duration
+    # Get 
all user-defined, named attributes + attrs_to_check = [ + a.name for a in attrs.fields(PET) if not a.name.startswith("_") and not a.name.isdigit() + ] + + # Sanity checks (element-wise) + for attr_name in attrs_to_check: + val_direct = getattr(pet_obj, attr_name) + val_loaded = getattr(loaded, attr_name) + + if val_direct is None or val_loaded is None: + assert val_direct is None and val_loaded is None, f"{attr_name} mismatch" + else: + if isinstance(val_direct, np.ndarray): + assert val_direct.shape == val_loaded.shape + assert np.allclose(val_direct, val_loaded), f"{attr_name} arrays differ" + else: + assert math.isclose(val_direct, val_loaded), f"{attr_name} values differ" + + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 15.0, 20.0, 25.0, 30.0])) +def test_pet_motion_estimator_run(monkeypatch, setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_pet_motion_estimator_run(random_dataset, monkeypatch): class DummyModel: def __init__(self, dataset, timepoints, xlim): self.dataset = dataset @@ -96,7 +122,7 @@ def run(self, cwd=None): monkeypatch.setattr("nifreeze.estimator.Registration", DummyRegistration) estimator = PETMotionEstimator(None) - affines = estimator.run(random_dataset) - assert len(affines) == len(random_dataset) + affines = estimator.run(pet_obj) + assert len(affines) == len(pet_obj) for mat in affines: np.testing.assert_array_equal(mat, np.eye(4)) diff --git a/test/test_model_pet.py b/test/test_model_pet.py index 6081a07d5..1c6bdcb4f 100644 --- a/test/test_model_pet.py +++ b/test/test_model_pet.py @@ -28,19 +28,11 @@ from nifreeze.model.pet import PETModel -@pytest.fixture -def random_dataset(setup_random_pet_data) -> PET: - """Create a PET dataset with random data for testing.""" - - ( - pet_dataobj, - affine, - brainmask_dataobj, - midframe, - total_duration, - ) = setup_random_pet_data - - return PET( +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0])) +def test_petmodel_fit_predict(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( dataobj=pet_dataobj, affine=affine, brainmask=brainmask_dataobj, @@ -48,13 +40,10 @@ def random_dataset(setup_random_pet_data) -> PET: total_duration=total_duration, ) - -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_fit_predict(random_dataset): model = PETModel( - dataset=random_dataset, - timepoints=random_dataset.midframe, - xlim=random_dataset.total_duration, + dataset=pet_obj, + timepoints=pet_obj.midframe, + xlim=pet_obj.total_duration, smooth_fwhm=0, thresh_pct=0, ) @@ -64,20 +53,40 @@ def test_petmodel_fit_predict(random_dataset): assert model.is_fitted # Predict at a specific timepoint - vol = model.fit_predict(random_dataset.midframe[2]) + vol = model.fit_predict(pet_obj.midframe[2]) assert vol is not None - assert vol.shape == random_dataset.shape3d - assert vol.dtype == random_dataset.dataobj.dtype + assert vol.shape == pet_obj.shape3d + assert vol.dtype == pet_obj.dataobj.dtype + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([1.0, 2.0, 3.0, 4.0, 5.0])) +def test_petmodel_invalid_init1(setup_random_pet_data): + 
pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_invalid_init(random_dataset): with pytest.raises(TypeError): - PETModel(dataset=random_dataset) + PETModel(dataset=pet_obj) + + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([5.0, 10.0, 15.0, 20.0, 25.0])) +def test_petmodel_time_check(setup_random_pet_data): + pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + pet_obj = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_time_check(random_dataset): bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32) with pytest.raises(ValueError): - PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0) + PETModel(dataset=pet_obj, timepoints=bad_times, xlim=60.0)
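---

A minimal end-to-end sketch of the loading API changed above. This is not part of the changeset: it assumes nifreeze with this patch applied, and it synthesizes throw-away data under a temporary directory. The sidecar only needs the BIDS ``FrameTimesStart`` entry that ``from_nii`` checks for, and passing a ``temporal_file`` keyword through the generic ``nifreeze.data.load`` dispatches to the PET loader:

import json
import tempfile
from pathlib import Path

import nibabel as nb
import numpy as np

from nifreeze.data import load
from nifreeze.data.pet import from_nii

# Illustrative only: random data in a throw-away directory
tmp = Path(tempfile.mkdtemp())

# Synthesize a tiny 4-frame PET series and its BIDS-style sidecar
nb.Nifti1Image(
    np.random.rand(4, 4, 4, 4).astype("float32"), np.eye(4)
).to_filename(tmp / "pet.nii.gz")
(tmp / "pet.json").write_text(json.dumps({"FrameTimesStart": [0.0, 30.0, 90.0, 180.0]}))

pet_dataset = from_nii(tmp / "pet.nii.gz", tmp / "pet.json")

# Equivalent through the generic loader: the temporal_file kwarg dispatches to PET
pet_dataset = load(tmp / "pet.nii.gz", temporal_file=tmp / "pet.json")

print(pet_dataset.midframe)        # [ 15.  60. 135. 225.]
print(pet_dataset.total_duration)  # 270.0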
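The docstring math in _compute_temporal_markers reduces to a couple of NumPy calls; a round-trip sketch, again illustrative only, relying on the private helper and on reconstruct_frame_time exactly as defined in this patch:

import numpy as np

from nifreeze.data.pet import _compute_temporal_markers, reconstruct_frame_time

frame_time = np.array([0.0, 5.0, 9.0, 12.0])

# d_k = t_{k+1} - t_k, last value duplicated -> [5, 4, 3, 3]
midframe, total_duration = _compute_temporal_markers(frame_time)

np.testing.assert_allclose(midframe, [2.5, 7.0, 10.5, 13.5])  # m_k = t_k + d_k / 2
assert total_duration == 15.0  # D = t_K + d_K = 12 + 3

# Inverting the markers recovers the origin-shifted start times
np.testing.assert_allclose(reconstruct_frame_time(midframe, total_duration), frame_time)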
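The attrs>=24.1.0 bump in pyproject.toml is what enables the field-aware converters attached to PET.midframe and PET.total_duration: attrs.Converter(..., takes_field=True) passes the attrs.Attribute to the converter so error messages can name the offending field. A self-contained sketch of that pattern, with hypothetical names that are not part of the changeset:

import attrs
import numpy as np

def as_1d_float_array(value, field):
    # takes_field=True: attrs passes the attrs.Attribute as the second argument
    if value is None:
        raise ValueError(f"'{field.name}' may not be None")
    return np.asarray(value, dtype=float)

@attrs.define
class Frames:
    midframe: np.ndarray = attrs.field(
        converter=attrs.Converter(as_1d_float_array, takes_field=True)
    )

print(Frames(midframe=[1, 2, 3]).midframe)  # [1. 2. 3.]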