@@ -222,13 +222,24 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
222222class PETMotionEstimator :
223223 """Estimates motion within PET imaging data aligned with generic Estimator workflow."""
224224
225- def __init__ (self , align_kwargs = None , strategy = "lofo" ):
225+ def __init__ (self , model : BaseModel | str , strategy = "linear" , model_kwargs : dict | None = None , align_kwargs = None ):
226+ self ._model = model
227+ self ._strategy = strategy
228+ self ._model_kwargs = model_kwargs or {}
226229 self .align_kwargs = align_kwargs or {}
227- self .strategy = strategy
228230
229- def run (self , pet_dataset , omp_nthreads = None ):
230- n_frames = len (pet_dataset )
231- frame_indices = np .arange (n_frames )
231+ def run (self , dataset : PET , omp_nthreads = None , ** kwargs ):
232+ # Prepare iterator
233+ iterfunc = getattr (iterators , f"{ self ._strategy } _iterator" )
234+ index_iter = iterfunc (len (dataset ), seed = kwargs .get ("seed" , None ))
235+
236+ model = ModelFactory .init (
237+ model = self ._model ,
238+ dataset = dataset ,
239+ ** self ._model_kwargs ,
240+ )
241+
242+ dataset_length = len (dataset )
232243
233244 if omp_nthreads :
234245 self .align_kwargs ["num_threads" ] = omp_nthreads
@@ -238,75 +249,55 @@ def run(self, pet_dataset, omp_nthreads=None):
238249 with TemporaryDirectory () as tmp_dir :
239250 tmp_path = Path (tmp_dir )
240251
241- for idx in tqdm (frame_indices , desc = "Estimating PET motion" ):
242- (train_data , train_times ), (test_data , test_time ) = pet_dataset .lofo_split (idx )
252+ with tqdm (total = dataset_length , unit = "vols." , desc = "Estimating PET motion" ) as pbar :
253+ for i in index_iter :
254+ pbar .set_description_str (f"{ FIT_MSG : <16} vol. <{ i } >" )
255+
256+ # Fit the model once on the training dataset
257+ model .fit_predict (None )
258+
259+ # Predict the reference volume at the test frame's timepoint
260+ predicted = model .fit_predict (i )
261+
262+ fixed_image_path = tmp_path / f"fixed_frame_{ i :03d} .nii.gz"
263+ moving_image_path = tmp_path / f"moving_frame_{ i :03d} .nii.gz"
264+
265+ fixed_img = nb .Nifti1Image (predicted , dataset .affine )
266+ moving_img = nb .Nifti1Image (i , dataset .affine )
267+
268+ moving_img = nb .as_closest_canonical (moving_img , enforce_diag = True )
269+
270+ fixed_img .to_filename (fixed_image_path )
271+ moving_img .to_filename (moving_image_path )
272+
273+ registration_config = files ("nifreeze.registration.config" ).joinpath (
274+ "pet-to-pet_level1.json"
275+ )
243276
244- if train_times is None :
245- raise ValueError (
246- f"train_times is None at index { idx } , check midframe initialization."
277+ registration = Registration (
278+ from_file = registration_config ,
279+ fixed_image = str (fixed_image_path ),
280+ moving_image = str (moving_image_path ),
281+ output_warped_image = True ,
282+ output_transform_prefix = f"ants_{ i :03d} " ,
283+ ** self .align_kwargs ,
247284 )
248285
249- # Build a temporary dataset excluding the test frame
250- train_dataset = PET (
251- dataobj = train_data ,
252- affine = pet_dataset .affine ,
253- brainmask = pet_dataset .brainmask ,
254- midframe = train_times ,
255- total_duration = pet_dataset .total_duration ,
256- )
257-
258- # Instantiate PETModel explicitly
259- model = PETModel (
260- dataset = train_dataset ,
261- timepoints = train_times ,
262- xlim = pet_dataset .total_duration ,
263- )
264-
265- # Fit the model once on the training dataset
266- model .fit_predict (None )
267-
268- # Predict the reference volume at the test frame's timepoint
269- predicted = model .fit_predict (test_time )
270-
271- fixed_image_path = tmp_path / f"fixed_frame_{ idx :03d} .nii.gz"
272- moving_image_path = tmp_path / f"moving_frame_{ idx :03d} .nii.gz"
273-
274- fixed_img = nb .Nifti1Image (predicted , pet_dataset .affine )
275- moving_img = nb .Nifti1Image (test_data , pet_dataset .affine )
276-
277- moving_img = nb .as_closest_canonical (moving_img , enforce_diag = True )
278-
279- fixed_img .to_filename (fixed_image_path )
280- moving_img .to_filename (moving_image_path )
281-
282- registration_config = files ("nifreeze.registration.config" ).joinpath (
283- "pet-to-pet_level1.json"
284- )
285-
286- registration = Registration (
287- from_file = registration_config ,
288- fixed_image = str (fixed_image_path ),
289- moving_image = str (moving_image_path ),
290- output_warped_image = True ,
291- output_transform_prefix = f"ants_{ idx :03d} " ,
292- ** self .align_kwargs ,
293- )
294-
295- try :
296- result = registration .run (cwd = str (tmp_path ))
297- if result .outputs .forward_transforms :
298- transform = nt .io .itk .ITKLinearTransform .from_filename (
299- result .outputs .forward_transforms [0 ]
300- )
301- matrix = transform .to_ras (
302- reference = str (fixed_image_path ), moving = str (moving_image_path )
303- )
304- affine_matrices .append (matrix )
305- else :
286+ try :
287+ result = registration .run (cwd = str (tmp_path ))
288+ if result .outputs .forward_transforms :
289+ transform = nt .io .itk .ITKLinearTransform .from_filename (
290+ result .outputs .forward_transforms [0 ]
291+ )
292+ matrix = transform .to_ras (
293+ reference = str (fixed_image_path ), moving = str (moving_image_path )
294+ )
295+ affine_matrices .append (matrix )
296+ else :
297+ affine_matrices .append (np .eye (4 ))
298+ print (f"No transforms produced for index { i } " )
299+ except Exception as e :
306300 affine_matrices .append (np .eye (4 ))
307- print (f"No transforms produced for index { idx } " )
308- except Exception as e :
309- affine_matrices .append (np .eye (4 ))
310- print (f"Failed to process frame { idx } due to { e } " )
301+ print (f"Failed to process frame { i } due to { e } " )
311302
312- return affine_matrices
303+ return affine_matrices
0 commit comments