Skip to content

Commit 8c7d13a

Browse files
committed
REF: Refactor PET estimator
Refactor PET estimator so that it is more closely aligned with the `Estimtor` class, and in order to eventually merge them.
1 parent cf23822 commit 8c7d13a

File tree

1 file changed

+63
-72
lines changed

1 file changed

+63
-72
lines changed

src/nifreeze/estimator.py

Lines changed: 63 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,24 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
222222
class PETMotionEstimator:
223223
"""Estimates motion within PET imaging data aligned with generic Estimator workflow."""
224224

225-
def __init__(self, align_kwargs=None, strategy="lofo"):
225+
def __init__(self, model: BaseModel | str, strategy="linear", model_kwargs: dict | None = None, align_kwargs=None):
226+
self._model = model
227+
self._strategy = strategy
228+
self._model_kwargs = model_kwargs or {}
226229
self.align_kwargs = align_kwargs or {}
227-
self.strategy = strategy
228230

229-
def run(self, pet_dataset, omp_nthreads=None):
230-
n_frames = len(pet_dataset)
231-
frame_indices = np.arange(n_frames)
231+
def run(self, dataset : PET, omp_nthreads=None, **kwargs):
232+
# Prepare iterator
233+
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
234+
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
235+
236+
model = ModelFactory.init(
237+
model=self._model,
238+
dataset=dataset,
239+
**self._model_kwargs,
240+
)
241+
242+
dataset_length = len(dataset)
232243

233244
if omp_nthreads:
234245
self.align_kwargs["num_threads"] = omp_nthreads
@@ -238,75 +249,55 @@ def run(self, pet_dataset, omp_nthreads=None):
238249
with TemporaryDirectory() as tmp_dir:
239250
tmp_path = Path(tmp_dir)
240251

241-
for idx in tqdm(frame_indices, desc="Estimating PET motion"):
242-
(train_data, train_times), (test_data, test_time) = pet_dataset.lofo_split(idx)
252+
with tqdm(total=dataset_length, unit="vols.", desc="Estimating PET motion") as pbar:
253+
for i in index_iter:
254+
pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")
255+
256+
# Fit the model once on the training dataset
257+
model.fit_predict(None)
258+
259+
# Predict the reference volume at the test frame's timepoint
260+
predicted = model.fit_predict(i)
261+
262+
fixed_image_path = tmp_path / f"fixed_frame_{i:03d}.nii.gz"
263+
moving_image_path = tmp_path / f"moving_frame_{i:03d}.nii.gz"
264+
265+
fixed_img = nb.Nifti1Image(predicted, dataset.affine)
266+
moving_img = nb.Nifti1Image(i, dataset.affine)
267+
268+
moving_img = nb.as_closest_canonical(moving_img, enforce_diag=True)
269+
270+
fixed_img.to_filename(fixed_image_path)
271+
moving_img.to_filename(moving_image_path)
272+
273+
registration_config = files("nifreeze.registration.config").joinpath(
274+
"pet-to-pet_level1.json"
275+
)
243276

244-
if train_times is None:
245-
raise ValueError(
246-
f"train_times is None at index {idx}, check midframe initialization."
277+
registration = Registration(
278+
from_file=registration_config,
279+
fixed_image=str(fixed_image_path),
280+
moving_image=str(moving_image_path),
281+
output_warped_image=True,
282+
output_transform_prefix=f"ants_{i:03d}",
283+
**self.align_kwargs,
247284
)
248285

249-
# Build a temporary dataset excluding the test frame
250-
train_dataset = PET(
251-
dataobj=train_data,
252-
affine=pet_dataset.affine,
253-
brainmask=pet_dataset.brainmask,
254-
midframe=train_times,
255-
total_duration=pet_dataset.total_duration,
256-
)
257-
258-
# Instantiate PETModel explicitly
259-
model = PETModel(
260-
dataset=train_dataset,
261-
timepoints=train_times,
262-
xlim=pet_dataset.total_duration,
263-
)
264-
265-
# Fit the model once on the training dataset
266-
model.fit_predict(None)
267-
268-
# Predict the reference volume at the test frame's timepoint
269-
predicted = model.fit_predict(test_time)
270-
271-
fixed_image_path = tmp_path / f"fixed_frame_{idx:03d}.nii.gz"
272-
moving_image_path = tmp_path / f"moving_frame_{idx:03d}.nii.gz"
273-
274-
fixed_img = nb.Nifti1Image(predicted, pet_dataset.affine)
275-
moving_img = nb.Nifti1Image(test_data, pet_dataset.affine)
276-
277-
moving_img = nb.as_closest_canonical(moving_img, enforce_diag=True)
278-
279-
fixed_img.to_filename(fixed_image_path)
280-
moving_img.to_filename(moving_image_path)
281-
282-
registration_config = files("nifreeze.registration.config").joinpath(
283-
"pet-to-pet_level1.json"
284-
)
285-
286-
registration = Registration(
287-
from_file=registration_config,
288-
fixed_image=str(fixed_image_path),
289-
moving_image=str(moving_image_path),
290-
output_warped_image=True,
291-
output_transform_prefix=f"ants_{idx:03d}",
292-
**self.align_kwargs,
293-
)
294-
295-
try:
296-
result = registration.run(cwd=str(tmp_path))
297-
if result.outputs.forward_transforms:
298-
transform = nt.io.itk.ITKLinearTransform.from_filename(
299-
result.outputs.forward_transforms[0]
300-
)
301-
matrix = transform.to_ras(
302-
reference=str(fixed_image_path), moving=str(moving_image_path)
303-
)
304-
affine_matrices.append(matrix)
305-
else:
286+
try:
287+
result = registration.run(cwd=str(tmp_path))
288+
if result.outputs.forward_transforms:
289+
transform = nt.io.itk.ITKLinearTransform.from_filename(
290+
result.outputs.forward_transforms[0]
291+
)
292+
matrix = transform.to_ras(
293+
reference=str(fixed_image_path), moving=str(moving_image_path)
294+
)
295+
affine_matrices.append(matrix)
296+
else:
297+
affine_matrices.append(np.eye(4))
298+
print(f"No transforms produced for index {i}")
299+
except Exception as e:
306300
affine_matrices.append(np.eye(4))
307-
print(f"No transforms produced for index {idx}")
308-
except Exception as e:
309-
affine_matrices.append(np.eye(4))
310-
print(f"Failed to process frame {idx} due to {e}")
301+
print(f"Failed to process frame {i} due to {e}")
311302

312-
return affine_matrices
303+
return affine_matrices

0 commit comments

Comments
 (0)